Compare commits

...

455 Commits

Author SHA1 Message Date
Wing Lian
0026fcc3df remove torch install for now
2023-09-01 08:15:22 -07:00
Wing Lian
b448c77148 address pr feedback 2023-08-29 22:45:22 -07:00
Wing Lian
c820d04669 gptq doesn't play well with sample packing 2023-08-29 12:15:31 -07:00
Wing Lian
588cd65a64 fix setup.py to use extra index url
install torch for tests
fix cuda version for autogptq index
set torch in requirements so that it installs properly
move gptq install around to work with github cicd
2023-08-29 12:02:51 -07:00
Wing Lian
caa80e891d don't need explicit peft install for tests 2023-08-29 12:02:51 -07:00
Wing Lian
ac37753aa2 remove old gptq docker 2023-08-29 12:02:50 -07:00
Wing Lian
a29560004b more tweaks and add yml 2023-08-29 12:02:00 -07:00
Wing Lian
1deb767fe8 auto gptq support 2023-08-29 11:31:14 -07:00
Wing Lian
548787daae customizable ascii art (#506) 2023-08-29 10:13:42 -07:00
Wing Lian
5ac3392075 support for datasets with multiple names (#480)
* support for datasets with multiple names

* update docs
2023-08-29 06:18:17 -07:00
Aman Gupta Karmani
e356b297cb remove --force-reinstall from Dockerfile to ensure correct pytorch version (#492) 2023-08-29 06:17:51 -07:00
NanoCode012
48c56470d0 Fix(doc): Clarify no amp to full yaml docs (#496) 2023-08-29 06:17:37 -07:00
Maxime
36b2e1cfee tweak: use default config file when only one file is present (#501) 2023-08-29 06:17:10 -07:00
Wing Lian
125cccb786 Refactor train cfg cli (#499)
* wip to cleanup cfg cli options

* fix launcher

* fix cli args
2023-08-29 05:37:53 -07:00
Aman Karmani
fd55bc87e2 use math.ceil instead of round /cc #498 2023-08-29 01:03:41 +00:00
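The switch from `round` to `math.ceil` above matters for step counts: `round` can drop the final partial batch. A minimal illustration with hypothetical numbers (not taken from the commit):

```python
import math

# Hypothetical numbers: 1000 samples with an effective batch size of 96.
samples, batch_size = 1000, 96

print(round(samples / batch_size))      # 10 -> the last partial batch is not counted
print(math.ceil(samples / batch_size))  # 11 -> every sample is covered by a step
```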
Birch-san
8e197f6fb4 pad_to_worst_case_seq_len boolean, for testing memory limits (#498)
* pad_to_worst_case_seq_len boolean, for testing memory limits

* remove collator_pad_to_longest option since it does nothing

see docs: https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding.padding

True and "longest" mean the same thing

* rename to `pad_to_sequence_len, and ensure 64 alignment

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-08-28 18:47:16 -04:00
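The "64 alignment" mentioned in the rename above amounts to rounding the pad target up to the next multiple of 64; a minimal sketch of that arithmetic (illustrative only, not the project's implementation):

```python
def aligned_pad_len(seq_len: int, multiple: int = 64) -> int:
    """Round a padding target up to the next multiple of `multiple` so padded
    batches land on tensor shapes that are friendlier to fused kernels."""
    return ((seq_len + multiple - 1) // multiple) * multiple

assert aligned_pad_len(2000) == 2048
assert aligned_pad_len(4096) == 4096
```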
Aman Karmani
267b7b24e5 simplify linear layer locator 2023-08-28 09:45:16 -04:00
Wing Lian
98bf76e236 fsdp requires params be the same type too (#493) 2023-08-28 04:33:50 -04:00
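FSDP flat-packs parameters into shared buffers, so a model mixing float16 and float32 parameters fails to wrap; a rough sketch of the kind of dtype normalization this implies (assumed pattern, not the patch itself):

```python
import torch

def unify_param_dtypes(model: torch.nn.Module, dtype: torch.dtype = torch.bfloat16) -> None:
    """Illustrative only: give every parameter the same dtype before wrapping
    the model with FSDP, which expects uniform parameter types."""
    for param in model.parameters():
        if param.dtype != dtype:
            param.data = param.data.to(dtype)
```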
NanoCode012
4c37bd0b54 Fix(tokenizer): Make sure to add pad for CodeLlamaTokenizer (#489) 2023-08-28 09:39:10 +09:00
Aman Gupta Karmani
f144e98a32 Merge pull request #485 from maximegmd/patch-4
fix: finetune model inference needs the dtype fix to work with flash-attn
2023-08-27 16:27:47 -04:00
Aman Karmani
3a011ea1ef fix condition and add logging 2023-08-27 20:09:26 +00:00
Aman Karmani
1f613e5aa7 Merge branch 'main' into patch-4 2023-08-27 19:57:34 +00:00
Aman Karmani
f319b0bc67 rename var and reformat 2023-08-27 19:55:11 +00:00
Maxime
7fd662dd89 Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
2023-08-27 21:01:43 +02:00
Maxime
9e699683d7 Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
2023-08-27 21:01:37 +02:00
mhenrichsen
35130711d6 Feat(cfg): Add code-llama configs for all sizes (#479)
* configs for all sizes

* update tokenizer type

---------

Co-authored-by: mhenrichsen <some_email@hey.com>
2023-08-27 10:20:17 +09:00
mhenrichsen
3fc9006298 Feat(deepspeed): Add zero2 config (#476)
* zero2 config

* config added

* linting

---------

Co-authored-by: mhenrichsen <some_email@hey.com>
2023-08-27 10:10:33 +09:00
NanoCode012
ad8be435ad Feat(doc): Update eval_steps doc (#487) 2023-08-27 10:09:09 +09:00
Charles O. Goddard
fe4d6baf92 Add example Llama 2 ReLoRA config (#471)
* Add example Llama 2 ReLoRA config

* Use adamw_bnb_8bit in example relora config
2023-08-27 10:08:34 +09:00
Aman Gupta Karmani
f31301063d Merge pull request #486 from OpenAccess-AI-Collective/adam-bnb-simpler
let transformers handle adamw_bnb_8bit
2023-08-26 20:44:19 -04:00
Aman Karmani
868530c39c let transformers handle adamw_bnb_8bit 2023-08-26 21:40:12 +00:00
Maxime
d03887fad5 ignore: address pr review 2023-08-26 22:45:45 +02:00
Maxime
17605b85d8 fix: inference did not move the model to the correct device (#483) 2023-08-26 16:40:56 -04:00
Maxime
a184549e4c ignore: linter 2023-08-26 22:36:14 +02:00
Maxime
f311df9462 fix: finetune model inference needs the dtype fix to work with flash-attn 2023-08-26 22:34:11 +02:00
Maxime
c500d02517 Fix missing 'packaging' wheel (#482) 2023-08-26 12:02:15 -04:00
Wing Lian
31f3e71764 fix checkpoints on multigpu (#481) 2023-08-26 12:00:03 -04:00
Aman Gupta Karmani
56c4a94caf Merge pull request #484 from OpenAccess-AI-Collective/reqs
allow newer deps in requirements.txt
2023-08-26 11:13:41 -04:00
Aman Karmani
c29117a0d7 allow newer deps 2023-08-26 15:06:05 +00:00
Wing Lian
0b7ba57ec4 fix types w lora (#478) 2023-08-25 02:03:24 -04:00
NanoCode012
71bd06243c Fix(tokenizer): Fix condition to add pad token (#477)
* Fix(tokenizer): Fix condition to add pad token

* chore: fix lint
2023-08-25 14:30:50 +09:00
Wing Lian
cb9797ef5a improve llama pad token handling (#475)
* improve llama pad token handling

* tweak logic to not clobber
2023-08-24 13:20:35 -04:00
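LLaMA tokenizers ship without a pad token, and the "don't clobber" tweak above means only adding one when none is already defined; a hedged sketch of that logic (the model id is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # hypothetical model id

# Only add a pad token if the tokenizer does not already define one.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
    # If a new token is added, the model's embedding table must be resized to match.
```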
Charles O. Goddard
bde3c5a478 ReLoRA implementation (with quantization) (#322)
* Experimental ReLoRA (+qlora) implementation

* Add CPU offload

* Remove local config

* Fix saving logic

* Remove redundant assert

* Fix logic errors

* Move ReLoRA into its own trainer class with a method override to create the proper scheduler

* Formatting & typing fixes

* Use safe_serialization

* Don't allow fsdp/deepspeed with ReLoRA

* Fix cpu-offload logic, enable multi gpu

* Document parameters and add comment

* Fix merge issue

* Smooth over some sharp edges

* Implement resume from checkpoint for relora

* Address review comments

* Fix saving logic

* Add necessary metadata to safetensors

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-23 23:07:18 -04:00
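ReLoRA's core idea is to periodically fold the low-rank adapter into the frozen base weight and restart it, so successive LoRA cycles accumulate into a higher-rank update. A rough, framework-agnostic sketch of one merge-and-reset cycle (not the trainer code merged here):

```python
import torch

def relora_merge_and_reset(base_weight: torch.Tensor, lora_A: torch.Tensor,
                           lora_B: torch.Tensor, scaling: float) -> None:
    """Illustrative only: merge the current low-rank update into the base
    weight, then reset the adapter so the next cycle learns a fresh update."""
    with torch.no_grad():
        base_weight += scaling * (lora_B @ lora_A)  # fold B @ A into W
        torch.nn.init.kaiming_uniform_(lora_A)      # re-initialize A
        lora_B.zero_()                              # zero B, as in standard LoRA init
    # In practice the adapter's optimizer state is also pruned and the LR
    # scheduler restarts with a short warmup after each reset.
```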
NanoCode012
55c23c7bcb Fix(doc): Clarify config (#466) 2023-08-23 11:56:01 -04:00
Wing Lian
c69faee7a7 workaround so training doesn't hang when packed dataloader batches aren't even (#461)
* workaround so training doesn't hang when packed dataloader batches aren't even

* don't bother labeling anything in the no-op data
2023-08-23 10:39:11 -04:00
Wing Lian
d5dcf9c350 fix test fixture b/c hf trainer tokenization changed (#464) 2023-08-23 04:04:49 -04:00
TearGosling
f4746507f6 feat: add Metharme prompt strategy (#446)
* Add Metharme tokenizing strategy

This strategy accounts for how the Metharme JSONLs are formatted and also adds duplicated EOS tokens, which can help trim model output length.
I haven't gotten the chance to test this yet, and probably won't have the chance for quite a bit, so I'm committing this now.

* Redo Metharme tokenizing strategy

lol

* fix: oops

* Rearrange a conditional

* chore: reformat code in accordance with linter

* chore: Make lint not freak out

* chore: fix lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-08-22 11:21:45 +09:00
Wing Lian
96deb6bd67 recast loralayer, norm, lmhead + embed token weights per original qlora (#393)
* recast loralayer, norm, lmhead + embed token weights per original qlora

* try again for the fix

* refactor torch dtype picking

* linter fixes

* missing import for LoraLayer

* fix install for tests now that peft is involved
2023-08-21 18:41:12 -04:00
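The recast above follows the original QLoRA training script, which keeps a few module families in specific dtypes for stability (norms in float32, LoRA layers and embedding/head weights in the compute dtype). A hedged approximation:

```python
import torch
from peft.tuners.lora import LoraLayer

def recast_modules(model: torch.nn.Module, compute_dtype: torch.dtype = torch.bfloat16) -> None:
    """Sketch of the per-module dtype recast done in the original QLoRA repo;
    the module-name checks are approximations, not axolotl's exact conditions."""
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.to(compute_dtype)
        elif "norm" in name:
            module.to(torch.float32)
        elif ("lm_head" in name or "embed_tokens" in name) and hasattr(module, "weight"):
            module.to(compute_dtype)
```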
Wing Lian
50682a3c06 always drop samples that are too long (#452) 2023-08-21 16:43:33 -04:00
Wing Lian
5a1985ba24 set env var for FSDP layer to wrap (#453) 2023-08-21 16:43:22 -04:00
Aman Gupta Karmani
5e9c6afa10 Merge pull request #451 from OpenAccess-AI-Collective/eval-is-causal
is_causal fix for evals?
2023-08-21 10:43:46 -07:00
Aman Karmani
a213d9972a fix eval regression caused in 13f7efaf74 2023-08-21 10:40:06 -07:00
Wing Lian
fbf49a4770 is_causal fix for evals? 2023-08-21 10:36:26 -04:00
Wing Lian
58cf7e7fed add missing positional arg (#450) 2023-08-21 04:10:19 -04:00
NanoCode012
04a42b6db1 feat(docs): improve user customized prompts (#443)
* feat(docs): improve user customized prompts

* feat(doc): add custom pretokenized instructions

* chore: clean old data folder

* chore: add new line
2023-08-20 23:59:43 -04:00
NanoCode012
919f4cac90 feat(doc): add pillow to lambda instructions (#445) 2023-08-20 23:59:23 -04:00
Wing Lian
ee262818ef fix evals (#447) 2023-08-20 23:39:42 -04:00
Wing Lian
9d629d8bff gracefully handle empty input (#442) 2023-08-20 09:18:18 -04:00
Wing Lian
d2e7f27240 support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)
* support user defined prompters, pretokenized datasets in config, local parquet, local arrow files

* fix user defined dataset types

* fix for system prompts

* fix tests

* fix checks for parquet and arrow

* aha moment that d.data_files isn't used

* add documentation for ds_type to add support for parquet and arrow
2023-08-20 09:17:49 -04:00
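Local parquet and arrow files load through the `datasets` library's packaged builders; a minimal illustration (the file paths are hypothetical):

```python
from datasets import load_dataset

# Hypothetical local files; the "parquet" / "arrow" builder names pick the loader.
parquet_ds = load_dataset("parquet", data_files="data/train.parquet")
arrow_ds = load_dataset("arrow", data_files="data/train.arrow")
```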
Philpax
d21318dfb9 docs(readme): add cd axolotl (#440) 2023-08-19 19:14:05 -04:00
Wing Lian
f733d0f31e disable eval using multipack for now (#437) 2023-08-19 10:35:04 -04:00
Wing Lian
008505c8ae fix comma, not a tuple (#436) 2023-08-19 00:57:40 -04:00
Wing Lian
b3f5e00ff5 use save_strategy from config if available (#434)
* use save_strategy from config if available

* update docs for save_strategy
2023-08-18 20:28:23 -04:00
Wing Lian
5247c5004e set env for FSDP offload params (#433) 2023-08-18 20:28:09 -04:00
mhenrichsen
cf6654769a flash attn pip install (#426)
* flash attn pip

* add packaging

* add packaging to apt get

* install flash attn in dockerfile

* remove unused whls

* add wheel

* clean up pr

fix packaging requirement for ci
upgrade pip for ci
skip build isolation for requirements to get flash-attn working
install flash-attn separately

* install wheel for ci

* no flash-attn for basic cicd

* install flash-attn as pip extras

---------

Co-authored-by: Ubuntu <mgh@mgh-vm.wsyvwcia0jxedeyrchqg425tpb.ax.internal.cloudapp.net>
Co-authored-by: mhenrichsen <some_email@hey.com>
Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-18 19:00:27 -04:00
Aman Gupta Karmani
06edf175ac standardize attn hijack patches (#381)
* split sdp attn into its own patch

* sync xformers patch to follow shared format and be diffable

* update flash-attn patch for 70B/GQA and inference using helper from flash-attn tests

* speed up flash-attn inference

* fix patch to check position ids and don't use multipack for evals

* copy LlamaModel.forward and LlamaDecoderLayer.forward into monkeypatch

* update forwards so we only calculate cu_seqlens once

* enable eval dataloader using multipack again

* fix the patch to work properly and work with FSDP

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-18 12:54:16 -04:00
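The "calculate cu_seqlens once" step above refers to the cumulative sequence lengths that variable-length flash-attention kernels expect for packed batches. Because position ids restart at 0 for each packed sample, they are enough to recover the boundaries; a hedged sketch of that derivation:

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Illustrative: position ids reset to 0 at the start of every packed
    sample, so the zeros mark sequence boundaries."""
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0, as_tuple=True)[0]
    seq_lens = torch.diff(starts, append=torch.tensor([pos.numel()]))
    return F.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)

# Two packed samples of lengths 3 and 2:
print(cu_seqlens_from_position_ids(torch.tensor([[0, 1, 2, 0, 1]])))  # tensor([0, 3, 5], dtype=torch.int32)
```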
mhenrichsen
0a228479b3 adds color (#425)
* adds color

* chore: lint

* fix for colorama

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-18 10:59:43 -04:00
Wing Lian
82e111aba9 remove extra accelearate in requirements (#430) 2023-08-18 10:56:14 -04:00
Wing Lian
8cace80175 fix fixture for new tokenizer handling in transformers (#428) 2023-08-17 17:01:52 -04:00
Wing Lian
1b7e8604bb fix orca prompts (#422) 2023-08-16 11:21:03 -04:00
NanoCode012
3d1f203b62 Fix(docs): Remove gptq+lora and fix xformer compat list (#423) 2023-08-16 13:56:48 +09:00
Wing Lian
d3d6fd6ae6 just resort to tags and use main-latest (#424) 2023-08-16 00:39:57 -04:00
NanoCode012
b7449a997f Fix(template): Inform to place stack trace to Issue (#417)
* Fix(template): Inform to place stack trace to Issue

* Update following suggestions

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-16 11:55:48 +09:00
Wing Lian
5f80b3560b use inputs for image rather than outputs for docker metadata (#420) 2023-08-15 18:26:59 -04:00
Wing Lian
24959091d7 hopefully improve the README (#419)
* hopefully improve the README

* exitcode -9 help

* table of contents

* formatting
2023-08-15 15:30:53 -04:00
Wing Lian
7af816699e tag with latest as well for axolotl-runpod (#418)
* tag with latest as well for axolotl-runpod

* no dev branch for now
2023-08-15 15:30:41 -04:00
mhenrichsen
f806e86a6e Merge pull request #413 from mhenrichsen/chore/update-deepseed-config
update path to align with fsdp example
2023-08-15 20:08:23 +02:00
NanoCode012
2b990eb628 Feat(doc): Add lr_quadratic_warmup to readme (#412) 2023-08-16 02:55:48 +09:00
mhenrichsen
bd8cab49c9 update path to align with fsdp example 2023-08-15 19:51:58 +02:00
NanoCode012
c01015f33f Fix(config): Update handling of deepspeed config (#404)
* Fix(config): Update handling of deepspeed config

* feat: auto set deepspeed env if deepspeed passed

* fix: update new deepspeed instructions
2023-08-16 01:22:43 +09:00
NanoCode012
72fe3f8e3d Fix(docs): Update flash attn requirements (#409) 2023-08-15 22:40:52 +09:00
Wing Lian
47961fdb8b update docs for tokenizer_legacy (#401)
* update docs for tokenizer_legacy

* add default info
2023-08-15 22:34:42 +09:00
NanoCode012
7ad37cb6d7 Fix(template): Remove iPhone/android from Issue template (#407) 2023-08-15 22:32:51 +09:00
Wing Lian
29241cf1e4 Ax art (#405)
* axolotl text art :D

* only print art on rank0

* lint and pr feedback
2023-08-15 08:34:30 -04:00
lightningRalf
31db0ecce4 add templates, CoC and contributing guide (#126)
* add templates, CoC and contributing guide

* Update .github/SECURITY.md

correct responsible person

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update bug-report.yaml

axolotl version switch with axolotl branch-commit

* update CONTRIBUTING doc

* update reporting link

* linter fixes

* chore: fix linter

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-08-15 07:41:05 -04:00
Wing Lian
da10af03e9 fix eval steps and strategy (#403) 2023-08-15 07:28:50 -04:00
Wing Lian
85cf4f8e2c better handling of empty input ids when tokenizing (#395)
* better handling of empty input ids when tokenizing

* Add warning if tokenizer resulted in empty result

* fix len comparison for linter
2023-08-15 01:09:59 -04:00
Aman Karmani
2e22404d2d add utils.data.prepare_dataset 2023-08-14 21:28:29 -07:00
NanoCode012
be294fd605 Feat(doc): Add how to save by epochs (#396) 2023-08-15 13:24:25 +09:00
Wing Lian
fc2d6be96d use context manager to run things on rank0 before others (#397) 2023-08-15 00:10:47 -04:00
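The rank-0-first pattern uses a distributed barrier so the main process can do one-time work (tokenizing, downloading) before the other ranks proceed; a hedged sketch of such a context manager (names are illustrative):

```python
from contextlib import contextmanager
import torch.distributed as dist

@contextmanager
def zero_first(is_main_process: bool):
    """Illustrative: non-main ranks wait at a barrier while rank 0 runs the
    body; rank 0 then waits so all ranks leave the block together."""
    if dist.is_initialized() and not is_main_process:
        dist.barrier()
    yield
    if dist.is_initialized() and is_main_process:
        dist.barrier()
```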
Wing Lian
1687be6a35 don't use mask expansion for inference (#392) 2023-08-14 20:52:54 -04:00
NanoCode012
41ecb451c2 Feat(doc): Add max_steps to readme (#389) 2023-08-15 00:34:22 +09:00
Gabriel Puliatti
3c2ad00d07 Feat(config): add max steps (#387) 2023-08-14 11:19:29 -04:00
florian peyron
5d48a10548 Added "epoch" evaluation_strategy (#388) 2023-08-14 10:59:23 -04:00
NanoCode012
73a0b6ead5 Feat(config): Add hub_strategy (#386) 2023-08-14 07:12:55 -04:00
florian peyron
63fdb5a7fb Error msg for sharegpt if conv has less than 2 msg (#379) 2023-08-14 17:40:40 +09:00
mhenrichsen
fdffef5940 new llama-2 default settings (#370)
* new default settings

* fix whitespace

* rm max packed sequence length

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
2023-08-14 17:39:09 +09:00
Wing Lian
919246fbc1 don't pass rope_scaling kwarg if it's None (#383) 2023-08-13 18:57:38 -04:00
Wing Lian
ffac902c1b bump flash-attn to 2.0.4 for the base docker image (#382) 2023-08-13 17:55:04 -04:00
Charles Goddard
15f6e57eaa Fix crash when running without CUDA 2023-08-13 13:36:40 -07:00
NanoCode012
729c299256 Feat(doc): Improve sharegpt doc (#378)
* Feat(doc): Improve sharegpt doc

* Fix typo
2023-08-14 00:36:00 +09:00
Wing Lian
86a91e260b save tokenizer before training starts (#380) 2023-08-13 11:28:58 -04:00
Aman Gupta Karmani
094fc2c6e6 try to detect accelerate and only use device_map=None in that case (#373) 2023-08-13 00:32:07 -04:00
Wing Lian
2dafa730ef Create FUNDING.yml 2023-08-13 00:30:34 -04:00
Wing Lian
343ac84e5a fix check for flash attn branching (#377) 2023-08-12 22:48:08 -04:00
Aman Karmani
0c967279ce remove unnecessary local variable 2023-08-13 01:58:39 +00:00
Aman Karmani
efb3b2c95e simplify load_tokenizer 2023-08-12 18:55:06 -07:00
Aman Karmani
7b55fe6419 improve GPU logging to break out pytorch cache and system mem 2023-08-12 18:52:57 -07:00
Aman Karmani
e029ab34ea quiet noise from llama tokenizer by setting pad token earlier 2023-08-12 18:31:40 -07:00
Aman Karmani
8cec513447 extract module for working with cfg 2023-08-12 18:25:27 -07:00
Aman Karmani
a13e45d548 fix DefaultDict.__or__ 2023-08-13 01:15:50 +00:00
Wing Lian
918f1b0dfb revert previous change and build ax images w docker on gpu (#371) 2023-08-12 20:23:00 -04:00
Wing Lian
c3fde36ada attempt to run non-base docker builds on regular cpu hosts (#369) 2023-08-12 19:07:38 -04:00
Wing Lian
2bb0b78975 Attention mask and position id fixes for packing (#285)
* fix attention mask with packing

* set position ids and use block diagonal attn mask

* fix expand mask for multiple batch items, make sure we pad position_ids

* don't move masks to cpu

* use multi pack dataloader w random sampler

* add position_ids back

* more fixes for dataloader integration

* est total tokens, fix field loop

* more fixes, position_ids seems broken

* more fixes for sample packing

* use distributed sampler, avoid accelerate prepare

* use accelerator prepare for dataloader

* fix for position_ids w packing

* Update src/axolotl/utils/dataloader.py

* validation for sample packing and doc

* more fixes for 4k and optimizations

* optimized expand mask fn

* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data

* fix rounding of len of batches to int

* better handling so that all devices have the same dataloader len

* fix step calc for packing

* pass sample packing efficiency to training args

* add a test for the mask expansion for sequence packing

* only process eval dataset for packing if not None

* don't split batches when packing

* weighted CE losses

* weighted CEL fixes

* limit packing to sequences of max seq len

* seq_len_multiple for packing

* make sure the chunk size is an int

* sample_packing_seq_len_multiplier config

* use cumulative seq len with var len flash attn v2 w packing

* properly calculate max len

* fix flash-attn, xformers, packing, support chatml

* fix chatml system prompt for openorca, legacy tokenizer opts

* add chatml

* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test

* fix test and pylint checks

* more packing and dataset optimizations and fixes

* filter w multiple cpus

* more fixes and optimizations

* fixes and go back to distributed sampler since batch sampler won't work

* fix counts by accounting for num devices

* fix steps calculation

* previous accelerate is still most performant

* add numba to requirements.

* use custom distributed checks

* fix sampler to prevent overfit w new epochs

* let's not cleanup the cached datasets

* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

* speed optimizations and set accelerate fsdp env vars

* optimize dataset concatenation?

* more optimizations for dataset handling

* fix import for annotation

* manual pre-commit fixes

* another sum optimization and bug fix for calc steps

* fix packing estimations

* fix formatting

* pylint problems

* add back flash attention branch for handling unpacked sequences separately

* Address PR feedback

* add optional sample packing config params to readme
2023-08-12 15:14:56 -04:00
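Packing several samples into one row means attention must not leak across sample boundaries, hence the block-diagonal mask referenced above: each token may only attend within its own packed sample. A small hedged illustration:

```python
import torch

def block_diagonal_mask(sample_ids: torch.Tensor) -> torch.Tensor:
    """Illustrative: sample_ids labels which packed sample each token belongs
    to; a token may only attend to tokens carrying the same label."""
    return sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

# Two samples packed into one row of five tokens: [A, A, A, B, B]
print(block_diagonal_mask(torch.tensor([1, 1, 1, 2, 2])).int())
```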
NanoCode012
a276c9c88d Fix(save): Save as safetensors (#363) 2023-08-13 01:22:52 +09:00
Morgan McGuire
7019509daa Add wandb_entity to wandb options, update example configs, update README (#361)
* Update wandb_entity and add wandb descriptions

* add wandb to config section

* remove trailing whitespace for pre-commit hook

* remove trailing whitespace for pre-commit hook

---------

Co-authored-by: Morgan McGuire <morganmcguire@Morgans-MacBook-Pro.local>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-12 12:17:11 -04:00
NanoCode012
96bd6ae1c4 Fix(model loading): Warn when model revision is passed to gptq (#364)
* fix(model loading): warn when model revision is passed to gptq

* chore: improve message
2023-08-13 01:16:59 +09:00
NanoCode012
e37d9358e6 Fix(message): Improve error message for bad format (#365) 2023-08-13 01:16:18 +09:00
NanoCode012
b5212068ac Feat: Add rope scaling (#343)
* Feat: Add rope scaling

* fix: move rope config
2023-08-13 00:50:15 +09:00
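RoPE scaling stretches the rotary position embeddings so a model can be run past its trained context window; in `transformers` it is a model-config field. A hedged example (model id and factor are illustrative):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical choice
# Double the usable context with linear RoPE scaling.
config.rope_scaling = {"type": "linear", "factor": 2.0}
```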
NanoCode012
289d5c403d feat(merge): save tokenizer on merge (#362) 2023-08-13 00:18:10 +09:00
Aman Gupta Karmani
35c8b90306 Merge pull request #355 from tmm1/bitsandbytes-fixes
bump to latest bitsandbytes release with major bug fixes
2023-08-11 15:15:38 -07:00
NanoCode012
fae6ed8092 Update README.md on pretraining_dataset (#360)
* Update README.md on pretraining_dataset

* Fix message
2023-08-11 12:17:07 +09:00
NanoCode012
94d03c8402 Clarify pre-tokenize before multigpu (#359) 2023-08-11 11:27:42 +09:00
Aman Gupta Karmani
11ddccb80f Merge pull request #356 from tmm1/load_model-args
simplify `load_model` signature
2023-08-09 18:24:34 -07:00
Aman Gupta Karmani
964312199e Merge pull request #354 from tmm1/gpu-util
GPU memory usage logging
2023-08-09 15:44:18 -07:00
Aman Karmani
718102271f simplify load_model signature 2023-08-09 22:36:02 +00:00
Aman Gupta Karmani
f5c11f8262 Merge pull request #350 from tmm1/group-len-false-examples
set `group_by_length` to false in all examples
2023-08-09 14:48:48 -07:00
Aman Karmani
fce40aab23 bump to latest bitsandbytes release with major bug fixes 2023-08-09 21:47:11 +00:00
Aman Karmani
9c314101d5 use newer pynvml package 2023-08-09 21:06:28 +00:00
Aman Karmani
e303d64728 log GPU memory usage 2023-08-09 18:26:28 +00:00
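Breaking GPU memory out into PyTorch-allocated, PyTorch-cached, and everything else can be done by combining `torch.cuda` counters with NVML totals; a rough sketch (not the logging code that landed):

```python
import torch
import pynvml

def gpu_memory_breakdown(device_index: int = 0):
    """Illustrative: split device memory into PyTorch allocations, the PyTorch
    caching allocator's extra reservation, and everything else on the GPU."""
    allocated = torch.cuda.memory_allocated(device_index)
    reserved = torch.cuda.memory_reserved(device_index)
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    total_used = pynvml.nvmlDeviceGetMemoryInfo(handle).used
    return allocated, reserved - allocated, total_used - reserved
```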
Aman Karmani
b4d1d22782 note pattern when using groups 2023-08-07 16:18:42 -07:00
Aman Karmani
9f99104038 update comment for group_by_length 2023-08-07 01:04:56 -07:00
Aman Karmani
36fefcf94b set group_by_length to false in examples 2023-08-06 23:59:09 -07:00
Wing Lian
176b888a63 ensure enable_input_require_grads is called on model before getting the peft model (#345) 2023-08-06 18:13:10 -04:00
Jan Philipp Harries
3392270544 experimental llama 2 chat support (#296)
* experimental llama 2 chat support

* few small fixes

* llama2_chat

* small fix to follow original implementation

* small fixes and added fixtures/tests

* fix -mixed up inference and finetuning conversations

* args - small fix

* small fix

* small adjustment and warning

* fix with pre-commit

---------

Co-authored-by: Jan Philipp Harries <jpdus@users.noreply.github.com>
2023-08-06 17:40:52 -04:00
Wing Lian
bb53a165f5 add a basic ds zero3 config (#347)
better defaults for ds
2023-08-06 17:19:51 -04:00
ssmi153
10405b9995 Update XFormers Attention Monkeypatch to handle Llama-2 70B (GQA) (#339)
* Fix XFormers attention for Llama-2 70B (GQA)

Updated XFormers MonkeyPatch to handle GQA as used in Llama-2 70B. All the updated code is taken directly from the Transformers library: 07360b6c9c (diff-06392bad3b9e97be9ade60d4ac46f73b6809388f4d507c2ba1384ab872711c51) from their modeling_llama.py file.

* Catch configs without pretraining_tp

* Whitespace bug fix

Command had accidentally been moved out of if-else block.

* pre-commit formatting fixes

Thanks to @winglian
2023-08-06 11:09:04 -04:00
Jan Philipp Harries
c93655c0a3 Added Orca Mini prompt strategy (#263)
* added Orca Mini prompt strategy

* maybe this fixed precommit errors?

* pre-commits passing

---------

Co-authored-by: Jan Philipp Harries <jpdus@users.noreply.github.com>
2023-08-06 03:16:41 +09:00
Wing Lian
fe285430bc optimize the iteration when tokenizeing large datasets (#332) 2023-08-04 12:12:05 -04:00
Aman Gupta Karmani
0d2e34f056 Merge pull request #336 from tmm1/flash-attn
Fix flash-attn + qlora not working with llama models
2023-08-03 16:25:30 -07:00
Aman Gupta Karmani
b56a6c0101 Merge pull request #337 from tmm1/readme-fix
update README
2023-08-03 15:14:17 -07:00
Aman Karmani
2eda9e02a9 fix typo 2023-08-03 21:04:12 +00:00
Aman Karmani
78b9efb7f4 scope flash-attn+qlora fix correctly, scope to llama, add comment 2023-08-03 19:19:39 +00:00
Aman Karmani
312a9fad07 move flash-attn monkey patch alongside the others 2023-08-03 17:20:49 +00:00
Aman Karmani
58d665943e python 3.10 and 3.11 both work fine, as does pytorch 2.1.0.dev 2023-08-03 16:47:25 +00:00
Aman Karmani
cc7e80026e there is no configs folder 2023-08-03 16:31:37 +00:00
mhenrichsen
dc71d8872a feat/llama-2 examples (#319)
* qlora llama-2

* qlora llama-2

* linting

* readme

* lora added

* linting

* change group_by_length

* 13b fitting on 24gb

* grouped lengths true

* add pad token

* change out dir

---------

Co-authored-by: Mads Henrichsen <mads@Brbar-tilhrende-Mads.local>
2023-08-03 19:22:48 +09:00
Aman Karmani
248bf90f89 ensure flash-attn fixes happen in both adapter/lora modes, and use torch_dtype 2023-08-02 20:15:03 +00:00
Wing Lian
77085ea24e qlora w flash attention fixes (#333) 2023-08-01 23:26:16 -04:00
Wing Lian
db2a3586f3 add peft install back since it doesn't get installed by setup.py (#331) 2023-07-31 16:31:53 -04:00
Wing Lian
6c9a87c8ee pin accelerate so it works with llama2 (#330) 2023-07-30 22:20:06 -04:00
Wing Lian
894cba09f3 fix FSDP save of final model (#329) 2023-07-30 21:46:44 -04:00
Wing Lian
41a4d15d43 update README for updated docker images (#328)
* update README for updated docker images

* update readme from pr feedback
2023-07-28 16:50:03 -04:00
Wing Lian
2c37bf6c21 Prune cuda117 (#327)
* drop cuda117/torch 1.13.1 from support, pin flash attention to v2.0.1, rm torchvision/torchaudio install

* gptq base build not needed. add sm 9.0 support
2023-07-26 16:27:49 -04:00
Wing Lian
9f69c4d8c1 latest HEAD of accelerate causes 0 loss immediately w FSDP (#321) 2023-07-24 11:23:56 -04:00
Wing Lian
3d4984b9a5 update prompts for open orca to match the paper (#317)
fix the test for the updated system tokenizer
2023-07-22 13:49:11 -04:00
Wing Lian
ff7f18d1ed disable gh cache for first step of docker builds too 2023-07-22 11:46:37 -04:00
Wing Lian
cf62cfd661 add runpod envs to .bashrc, fix bnb env (#316)
* hopper support for base dockerfile, add runpod envs to .bashrc

* set BNB_CUDA_VERSION env for latest bnb

* don't support hopper yet w 118
2023-07-22 10:09:38 -04:00
Wing Lian
c5df969262 don't use the gha cache w docker 2023-07-22 08:46:21 -04:00
Wing Lian
40a53ff181 Merge pull request #307 from OpenAccess-AI-Collective/xgen-user-sharegpt-tokens
better handling since xgen tokenizer breaks with convert_tokens_to_ids
2023-07-22 04:10:38 -04:00
Wing Lian
dcdec44347 Merge pull request #306 from ethanhs/xgen
Add XGen info to README and example config
2023-07-22 04:10:18 -04:00
Wing Lian
3ffb018a4c Merge pull request #313 from OpenAccess-AI-Collective/tokenizer-llama2-embeddings
don't resize embeddings to multiples of 32x by default
2023-07-22 04:09:59 -04:00
Wing Lian
a94f2eecb1 Merge pull request #299 from OpenAccess-AI-Collective/flash-attention-2
Flash attention 2
2023-07-22 04:07:48 -04:00
Wing Lian
1066751358 don't resize embeddings to multiples of 32x by default 2023-07-22 01:52:38 -04:00
Wing Lian
1b63bf13bc Merge pull request #308 from OpenAccess-AI-Collective/apache2-license
add apache 2.0 license
2023-07-21 09:50:14 -04:00
Wing Lian
5cce2a42ff add apache 2.0 license 2023-07-21 09:49:29 -04:00
Wing Lian
2a428e8014 better handling since xgen tokenizer breaks with convert_tokens_to_ids 2023-07-21 09:24:11 -04:00
Wing Lian
cdf85fdbd5 pin flash attention 2 to the fix for backwards pass 2023-07-21 08:18:53 -04:00
Wing Lian
9b790d359b flash attention 2 2023-07-21 08:17:46 -04:00
Ethan Smith
38811434e6 Add XGen info to README and example config 2023-07-21 00:44:50 -07:00
NanoCode012
06c61d6f13 Merge pull request #304 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix(readme): Improve wording for push model
2023-07-21 13:39:45 +09:00
Wing Lian
262dc29df2 Merge pull request #300 from OpenAccess-AI-Collective/pytorch-201
Pytorch 2.0.1
2023-07-21 00:28:38 -04:00
NanoCode012
165907fddb Fix(readme): Improve wording for push model 2023-07-21 11:28:35 +09:00
Wing Lian
a032c9f452 fix sdp attention to use the flash/mem-efficient context manager 2023-07-20 01:05:48 -04:00
Wing Lian
b06d3e3645 explicitly pin flash attention 1 to v1.0.9 2023-07-20 01:02:08 -04:00
Wing Lian
c58034d48c use pytorch 2.0.1 2023-07-20 00:47:13 -04:00
NanoCode012
28fd429bcf Merge pull request #293 from NanoCode012/fix/tokenize-speed
Fix(tokenizing): Use multi-core
2023-07-19 11:02:04 +09:00
NanoCode012
45ac7c4f88 feat: use multi-core 2023-07-19 10:16:54 +09:00
Wing Lian
edd6980dd9 Merge pull request #289 from OpenAccess-AI-Collective/hf_transfer
add hf_transfer to requirements for faster hf upload
2023-07-17 15:08:06 -04:00
Wing Lian
dc6d25124d Merge pull request #288 from OpenAccess-AI-Collective/NanoCode012-patch-1
fix(readme): remove accelerate config
2023-07-17 14:46:43 -04:00
Wing Lian
6dd2e7d671 add hf_transfer to requirements for faster hf upload 2023-07-17 14:44:48 -04:00
NanoCode012
b64f411849 fix(readme): remove accelerate config 2023-07-18 01:31:02 +09:00
Wing Lian
03a59c1ed4 Merge pull request #287 from OpenAccess-AI-Collective/dataclass-fix
fix axolotl training args dataclass annotation
2023-07-17 06:09:23 -04:00
Wing Lian
ebaec3c406 fix axolotl training args dataclass annotation 2023-07-17 04:57:02 -04:00
Wing Lian
73e70e3996 Merge pull request #286 from OpenAccess-AI-Collective/logging-docker-fixes
misc fixes
2023-07-17 04:26:39 -04:00
Wing Lian
d75adb9835 misc fixes 2023-07-17 03:00:27 -04:00
Wing Lian
02224668c3 Merge pull request #283 from OpenAccess-AI-Collective/docker-git-fetch
git fetch fix for docker
2023-07-17 02:17:00 -04:00
Wing Lian
f162f3c7cc set transformers cache env var in docker image 2023-07-16 23:03:54 -04:00
Wing Lian
eca3531329 git fetch fix for docker 2023-07-16 22:25:05 -04:00
Wing Lian
6f16c4569d Merge pull request #276 from theobjectivedad/logging_enhancement
Logging update: added PID and formatting
2023-07-16 17:04:52 -04:00
Wing Lian
0bd09c077d Merge pull request #280 from teknium1/main
Update requirements.txt
2023-07-16 16:08:58 -04:00
Wing Lian
469c08c9ba Merge pull request #279 from NanoCode012/feat/multi-gpu-readme
Feat(readme): improve docs on multi-gpu
2023-07-16 16:08:37 -04:00
Wing Lian
334af625d0 Merge pull request #277 from cg123/dataset-name
Allow non-default dataset configurations
2023-07-16 16:08:15 -04:00
Teknium
273b3a3aa7 Update requirements.txt
Require latest git accelerate to fix saving checkpoint issue
2023-07-16 10:24:24 -07:00
Charles Goddard
3cdd8e4122 Add dataset name to all yaml options in README 2023-07-15 13:17:37 -07:00
NanoCode012
cf5ae6b649 Feat(readme): improve docs on multi-gpu 2023-07-16 01:07:27 +09:00
theobjectivedad
b1f4f7a34d Fixed pre-commit problems, fixed small bug in logging_config to handle LOG_LEVEL env var 2023-07-15 12:29:35 +00:00
The Objective Dad
83237b8445 Merge branch 'OpenAccess-AI-Collective:main' into logging_enhancement 2023-07-15 06:16:04 -05:00
Charles Goddard
46032a1a1f Fix formatting mistake 2023-07-14 20:57:27 -07:00
Charles Goddard
8bba64258e Add example of dataset with configuration name to README 2023-07-14 20:46:21 -07:00
Charles Goddard
88089e8b32 Add ability to pass 'name' argument to load_dataset 2023-07-14 16:46:39 -07:00
NanoCode012
168a7a09cc Merge pull request #274 from OpenAccess-AI-Collective/NanoCode012-patch-2
Feat: Set push to hub as private by default
2023-07-14 23:15:47 +09:00
NanoCode012
231031a0e1 Merge pull request #275 from NanoCode012/feat/safetensors
Feat: Add save_safetensors
2023-07-14 23:07:26 +09:00
theobjectivedad
9234b75cb4 Update log message format, IMO this is easier to read. 2023-07-14 07:36:21 -05:00
theobjectivedad
553a86b52c Adding logging enhancement 2023-07-14 07:26:19 -05:00
NanoCode012
5daf7d5299 Merge pull request #273 from OpenAccess-AI-Collective/NanoCode012-patch-1
Feat(docs): Add model_revision arg
2023-07-14 21:09:50 +09:00
NanoCode012
5491278a79 Feat: Add save_safetensors 2023-07-14 13:21:47 +09:00
NanoCode012
1514739f0f Set push to hub as private by default 2023-07-14 13:17:49 +09:00
NanoCode012
896c1aebcf Feat(docs): Add model_revision arg 2023-07-14 12:56:07 +09:00
Wing Lian
ef17e15483 Merge pull request #272 from OpenAccess-AI-Collective/model-revision
support for loading a model by git revision
2023-07-13 23:12:00 -04:00
Wing Lian
69a235061b support for loading a model by git revision 2023-07-13 22:58:25 -04:00
Wing Lian
687d889928 Merge pull request #271 from OpenAccess-AI-Collective/quadratic-warmup
Quadratic warmup
2023-07-10 12:48:02 -04:00
Wing Lian
c4cf567b55 Merge branch 'main' into quadratic-warmup 2023-07-10 12:42:12 -04:00
Wing Lian
c49729d2bc better configuration for quadratic warmup 2023-07-10 11:52:59 -04:00
Wing Lian
13ac4d8de2 Merge pull request #268 from OpenAccess-AI-Collective/fix-adam-args
params are adam_*, not adamw_*
2023-07-08 12:33:34 -04:00
Wing Lian
19cf0bda99 params are adam_*, not adamw_* 2023-07-08 12:13:39 -04:00
Wing Lian
f74edd5b56 Merge pull request #266 from OpenAccess-AI-Collective/trust-remote-no-llama 2023-07-07 21:38:11 -04:00
Wing Lian
d69da99c2c skip explicit model type too if using trust_remote_code 2023-07-07 21:33:11 -04:00
Wing Lian
66afb76a15 don't use llama if trust_remote_code is set since that needs to use AutoModel path 2023-07-07 21:31:02 -04:00
NanoCode012
a692ad3f4c Merge pull request #264 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix(readme): local path loading and custom strategy type
2023-07-06 23:34:57 +09:00
NanoCode012
41da98b982 Fix for linter 2023-07-06 23:20:11 +09:00
NanoCode012
9e64f42e0f Fix local path loading and custom strategy type 2023-07-06 23:08:09 +09:00
Wing Lian
b9b7d4ce92 Merge pull request #221 from utensil/local_dataset
[WIP] Support loading data files from a local directory
2023-07-03 09:10:13 -04:00
Wing Lian
9bed281867 Merge pull request #258 from NanoCode012/fix/deprecate-push
Fix future deprecation push_to_hub_model_id
2023-07-03 09:08:26 -04:00
NanoCode012
e79c8e617e Fix future deprecation push_to_hub_model_id 2023-07-03 12:44:29 +09:00
Wing Lian
71456955f5 pin pydantic so deepspeed isn't broken 2023-07-02 22:26:51 -04:00
Wing Lian
3a783c04e4 Merge pull request #247 from OpenAccess-AI-Collective/fix-apex-base
update pip install command for apex
2023-07-01 06:18:25 -04:00
Wing Lian
1e5014acec Merge pull request #255 from OpenAccess-AI-Collective/open-orca-prompts
open orca support
2023-07-01 01:11:23 -04:00
Wing Lian
a10da1caff 11.7.0 nvidia/cuda docker images are deprecated, move to 11.7.1
2023-07-01 00:29:07 -04:00
Wing Lian
4066c78631 Merge pull request #246 from OpenAccess-AI-Collective/sys-prompts-instruct
add option for instruct w sys prompts
2023-07-01 00:27:29 -04:00
Wing Lian
78a1e1fa12 open orca support 2023-07-01 00:19:41 -04:00
NanoCode012
bc8a2e5547 Merge pull request #249 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix typing list in prompt tokenizer
2023-06-30 15:01:41 +09:00
NanoCode012
910ebe47f5 Merge pull request #252 from OpenAccess-AI-Collective/NanoCode012-readme-fix
Add cfg.push_to_hub_model_id to readme
2023-06-30 14:56:55 +09:00
NanoCode012
c146880a75 Update README.md 2023-06-30 11:33:53 +09:00
NanoCode012
77bdb7d144 Fix typing list 2023-06-29 14:29:55 +09:00
Wing Lian
530809fd74 update pip install command for apex 2023-06-28 22:36:28 -04:00
Wing Lian
924bbfddec add option for instruct w sys prompts 2023-06-28 22:27:17 -04:00
Wing Lian
f150c027e3 Merge pull request #224 from OpenAccess-AI-Collective/system-prompt-data
System prompt data
2023-06-27 17:57:43 -04:00
Wing Lian
5c39c006c9 Merge pull request #244 from OpenAccess-AI-Collective/push-to-hub
push intermediate model checkpoints to hub
2023-06-27 17:57:30 -04:00
Wing Lian
612aabd8c4 push intermediate model checkpoints to hub 2023-06-27 15:40:25 -04:00
Wing Lian
af05883f75 Merge pull request #243 from OpenAccess-AI-Collective/unprompted-instruct
skip the system prompt
2023-06-25 22:50:35 -04:00
Wing Lian
05ab9092e3 skip the system prompt 2023-06-25 22:40:50 -04:00
Wing Lian
7b57ed7618 pylint for duplicated code for system prompts 2023-06-25 22:28:07 -04:00
Wing Lian
3a38271276 add tests and support for loader for sys prompt data 2023-06-25 22:28:07 -04:00
Wing Lian
8d20e0a3d3 initial wip to get sys prompt from dataset 2023-06-25 22:28:07 -04:00
Wing Lian
de8ed229c3 Merge pull request #240 from OpenAccess-AI-Collective/tokenizer-fast
optionally define whether to use_fast tokenizer
2023-06-25 12:47:55 -04:00
Wing Lian
478d8c7b8e Merge pull request #241 from OpenAccess-AI-Collective/py3-pre-commit
better py3 support w pre-commit
2023-06-25 12:47:02 -04:00
Wing Lian
645c13592c better py3 support w pre-commit 2023-06-25 10:26:02 -04:00
Wing Lian
47d601fa23 optionally define whether to use_fast tokenizer 2023-06-25 10:19:49 -04:00
Wing Lian
756dfba97b Merge pull request #218 from OpenAccess-AI-Collective/no-fail-fast
don't fail fast
2023-06-23 15:42:54 -04:00
Wing Lian
91ab0592af Merge pull request #235 from msinha251/Fixing-data-readme 2023-06-23 13:52:01 -04:00
Mahesh Sinha
0aeb7c7802 Fixing Data Readme 2023-06-21 15:34:48 +02:00
Utensil
9bdd30cdfd Support loading data files from a local directory
ref:  https://huggingface.co/docs/datasets/v2.13.0/en/package_reference/loading_methods#datasets.load_dataset.path
2023-06-21 08:00:58 +00:00
Wing Lian
d35278aaf1 don't fail fast 2023-06-15 16:01:27 -04:00
Wing Lian
9492d4ebb7 Merge pull request #215 from OpenAccess-AI-Collective/adamw-hyperparams-cfg
support adamw and grad norm hyperparams
2023-06-15 12:20:55 -04:00
Wing Lian
ad5ca4f734 Additional test case per pr 2023-06-15 10:12:47 -04:00
Wing Lian
cb9d3af5c0 add validation and tests for adamw hyperparam 2023-06-15 09:39:42 -04:00
Wing Lian
c969f0a9dc add docs 2023-06-15 08:43:20 -04:00
Wing Lian
6d0ee4ba34 support adamw and grad norm hyperparams 2023-06-15 08:40:41 -04:00
Wing Lian
a81f52d575 Merge pull request #212 from OpenAccess-AI-Collective/doc-20230615-v1
add float16 docs and tweak typehints
2023-06-15 08:28:57 -04:00
Wing Lian
1925eaf1e6 Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Fix tokenizing labels
2023-06-15 08:13:43 -04:00
Wing Lian
1ab3bf3e67 fix test name 2023-06-15 02:09:33 -04:00
Wing Lian
d7635b7148 hint to what AMP means 2023-06-15 02:06:27 -04:00
Wing Lian
88e17ffc50 add float16 docs and tweak typehints 2023-06-15 02:05:31 -04:00
Wing Lian
baed440fa1 ignore duplicate code in tests 2023-06-15 02:03:53 -04:00
Wing Lian
7925ddce86 bugfix for potential off by one 2023-06-15 01:59:33 -04:00
Wing Lian
6f849809c5 Merge pull request #206 from MaciejKarasek/issue205
issue #205 bugfix
2023-06-14 14:23:38 -04:00
Wing Lian
c16644d05e Merge pull request #209 from sroecker/fix_redpajama_example_tokenizer
Use AutoTokenizer for redpajama example
2023-06-14 14:23:21 -04:00
Steffen Röcker
945c4191a3 Use AutoTokenizer for redpajama example 2023-06-14 20:09:26 +02:00
maciej.karasek
136522f9c9 style correction 2023-06-14 20:02:09 +02:00
maciej.karasek
556fe408b3 issue #205 bugfix 2023-06-14 16:59:57 +02:00
Wing Lian
16bb6276a5 Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum
add support for optimum bettertransformers
2023-06-14 07:50:15 -04:00
NanoCode012
06674a11f2 Merge pull request #202 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix sharegpt type in doc
2023-06-14 09:48:35 +09:00
NanoCode012
3513885f43 Fix sharegpt type 2023-06-14 01:10:58 +09:00
Wing Lian
06652c1c39 Merge pull request #196 from OpenAccess-AI-Collective/openllama-ft-config
tweak config to work
2023-06-13 11:51:04 -04:00
NanoCode012
068fc48978 Merge pull request #199 from NanoCode012/chore/prompter-arg
chore: Refactor inf_kwargs out
2023-06-13 17:56:22 +09:00
Wing Lian
aaadacf6b3 Merge pull request #200 from PocketDocLabs/main
Update README.md to include a community showcase
2023-06-13 04:44:34 -04:00
PocketDoc Labs
5ff547dc70 Update README.md to include a community showcase 2023-06-12 22:38:10 -07:00
NanoCode012
dc77c8ebce chore: Refactor inf_kwargs out 2023-06-13 12:01:46 +09:00
NanoCode012
51a4c12242 Merge pull request #197 from mhenrichsen/chore/update-readme
chore: Fix inference README.
2023-06-13 11:53:26 +09:00
Wing Lian
4b43a66a0b update alpaca_chat prompts for instructions to explain the conversation 2023-06-12 18:38:38 -04:00
mhenrichsen
34ae69989f fix inference 2023-06-12 21:39:19 +02:00
Wing Lian
7dc580b837 add axolotl trainer and quadratic warmup 2023-06-12 13:16:40 -04:00
Wing Lian
fd2c9814c9 Merge branch 'main' into flash-optimum 2023-06-12 13:12:15 -04:00
Wing Lian
2ba4ae8f46 tweak config to work 2023-06-12 10:07:18 -04:00
Wing Lian
93dacba228 Merge pull request #187 from OpenAccess-AI-Collective/strip-peft-device-map
peft no longer needs device_map
2023-06-12 09:10:49 -04:00
Wing Lian
8002ffb41f Merge pull request #177 from NanoCode012/fix/landmark-patch
Fix landmark attention patch
2023-06-12 08:27:12 -04:00
Wing Lian
74ef5cc083 Merge pull request #192 from OpenAccess-AI-Collective/sharegpt-custom-prompt
misc fixes
2023-06-12 08:26:38 -04:00
Wing Lian
5e616d91c0 Merge branch 'main' into strip-peft-device-map 2023-06-12 08:25:54 -04:00
Wing Lian
94f310c7a6 Merge pull request #193 from OpenAccess-AI-Collective/config-fixes-20230612
config fixes
2023-06-12 08:24:52 -04:00
NanoCode012
8e568bbdae Merge pull request #159 from AngainorDev/patch-1
Fix training over existing lora
2023-06-12 20:27:11 +09:00
NanoCode012
e21dab49fd Merge pull request #194 from NanoCode012/fix/config-path
Fix config path after config moved
2023-06-12 19:28:12 +09:00
NanoCode012
52cde69288 Fix config path after config moved 2023-06-12 17:06:15 +09:00
Wing Lian
9a58e99e81 config fixes 2023-06-12 01:52:58 -04:00
Wing Lian
c7dee56b87 add typehints 2023-06-11 19:52:34 -04:00
Wing Lian
aac4b7691e add new sharegpt, refactor prompt so it can be customized later, add exception if no data is processed 2023-06-11 19:42:25 -04:00
NanoCode012
f31a338cbb Merge pull request #191 from OpenAccess-AI-Collective/NanoCode012-patch-1
Add save_steps and eval_steps to Readme
2023-06-12 02:55:37 +09:00
NanoCode012
4cd1deeef2 Add save_steps and eval_steps to Readme 2023-06-12 02:44:46 +09:00
Wing Lian
9ac16ed8d1 Merge pull request #190 from OpenAccess-AI-Collective/fixes-20230711-v2
more config pruning and migrating
2023-06-11 13:27:08 -04:00
Wing Lian
6b3f509d9e forgot to add this file 2023-06-11 11:50:12 -04:00
Wing Lian
336aa3fd48 gptq lora llama is obviously good 2023-06-11 11:05:29 -04:00
Wing Lian
d0d7eaa4f3 update openllama and clean up paths 2023-06-11 11:03:31 -04:00
Wing Lian
a6ebf57e82 fix table formatting 2023-06-11 10:55:32 -04:00
Wing Lian
280832cec2 more matrix updates 2023-06-11 10:52:36 -04:00
Wing Lian
a43bae9ff0 update the support matrix 2023-06-11 10:44:03 -04:00
Wing Lian
effbbf6dd1 more pruning 2023-06-11 10:38:24 -04:00
Wing Lian
c9a149f9e8 add check for attr 2023-06-11 10:11:17 -04:00
Wing Lian
c530e4b9c8 more config pruning and migrating 2023-06-11 10:09:05 -04:00
Wing Lian
f620706776 Merge pull request #189 from OpenAccess-AI-Collective/fixes-20230711
various fixes
2023-06-11 09:49:23 -04:00
Wing Lian
77762a5d6b get rid of some configs, formalize pythioa lora config 2023-06-11 09:41:41 -04:00
Wing Lian
14668fa54e new validation for mpt w grad checkpoints 2023-06-11 09:26:10 -04:00
AngainorDev
b565ecf0a1 Fix strict and Lint 2023-06-11 15:23:38 +02:00
Wing Lian
fe0b76854e match up gradient checkpointing when using lora w config 2023-06-11 09:20:40 -04:00
NanoCode012
e944311442 Merge pull request #186 from akj2018/main
Update FAQS.md
2023-06-11 19:45:06 +09:00
Akshay Jain
e3e7b52a5b Update FAQS.md
Converted (```) to single backtick (') uniformly.
2023-06-10 23:36:14 -07:00
NanoCode012
974dc00a7d Fix set mem_id for inference and refactor 2023-06-11 14:00:54 +09:00
NanoCode012
572d1141e6 Set mem cache args on inference 2023-06-11 12:05:37 +09:00
NanoCode012
a6190c8094 Clean up landmark patching 2023-06-11 11:59:03 +09:00
NanoCode012
563b6d89e6 Fix undefined LlamaForCausalLM and del try except 2023-06-11 11:58:31 +09:00
Wing Lian
cd0a6f6027 peft no longer needs device_map 2023-06-10 22:50:09 -04:00
Akshay Jain
0e664a5ebc Update FAQS.md
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-10 19:26:12 -07:00
Akshay Jain
dd7d16d2eb Update FAQS.md
Updated FAQS.md with backticks around error message
2023-06-10 19:15:50 -07:00
NanoCode012
e285e24f7f Address PR suggestion
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-06-11 10:52:12 +09:00
NanoCode012
919727b4d7 Refactor landmark attention patch 2023-06-11 10:51:05 +09:00
Akshay Jain
5ffefee37f Update FAQS.md
Update FAQS.md with the following statement

Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c
/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598:  arrow::fs::FinalizeS3 was not called even though S3 was initialized.  This could lead to a segmentation fault at exit

try reinstalling bitsandbytes and transformers from source
2023-06-10 18:34:54 -07:00
Wing Lian
d9f713e4e3 Merge pull request #183 from OpenAccess-AI-Collective/inference-from-stdin
pass a prompt in from stdin for inference
2023-06-10 17:06:55 -04:00
Wing Lian
958da70376 fix formatting 2023-06-10 15:28:08 -04:00
Wing Lian
c4e4f8115c pass a prompt in from stdin for inference 2023-06-10 15:07:40 -04:00
Angainor Development
a808bf913f Fix missing cfg. 2023-06-10 20:28:49 +02:00
Wing Lian
01248253a3 Merge pull request #182 from OpenAccess-AI-Collective/fix-llama-ref
fix for local variable 'LlamaForCausalLM' referenced before assignment
2023-06-10 14:25:51 -04:00
Wing Lian
759e8673ce Update scripts/finetune.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-10 14:25:21 -04:00
Wing Lian
0c6f928601 address PR feedback 2023-06-10 14:23:56 -04:00
Wing Lian
eea2731a5e add streaming dataset support for pretraining datasets 2023-06-10 14:23:56 -04:00
Wing Lian
1db46a9c72 linting fix 2023-06-10 14:23:56 -04:00
Wing Lian
ab5cd28acf more gpt-neox long ctx fixes 2023-06-10 14:23:55 -04:00
Wing Lian
1a82082e91 fix bettertransformers save, force it to skip after saving correctly in callback 2023-06-10 14:23:55 -04:00
Wing Lian
1210dc8fd5 more tweaks to do pre-training with bettertransformers 2023-06-10 14:23:55 -04:00
Wing Lian
488a67d75a experimental expansion of ctx len 2023-06-10 14:23:53 -04:00
Wing Lian
71a43f8479 add validation/warning for bettertransformers and torch version 2023-06-10 14:22:31 -04:00
Wing Lian
39619028a3 use pythia-12b, neox-20b is flaky 2023-06-10 14:22:30 -04:00
Wing Lian
8792199799 add flash attn context for efficient training and attempt setting model to train mode: 2023-06-10 14:22:30 -04:00
Wing Lian
1edc30c786 add support for optimum bettertransformers 2023-06-10 14:22:30 -04:00
Wing Lian
14163c15d9 fix for local variable 'LlamaForCausalLM' referenced before assignment 2023-06-10 14:11:13 -04:00
Wing Lian
41e4f6ca31 Merge pull request #181 from OpenAccess-AI-Collective/xpos-rope
add support to extend context with xpos rope
2023-06-10 14:04:03 -04:00
Angainor Development
79e2a6f140 Merge branch 'main' into patch-1 2023-06-10 19:07:54 +02:00
Angainor Development
c2508987a6 Remove explicit definition of cfg.inference 2023-06-10 19:06:10 +02:00
Wing Lian
215d775147 Merge pull request #180 from Glavin001/feat/stream-inference
Add streaming inference & fix stopping at EOS
2023-06-10 12:04:34 -04:00
Wing Lian
f36e227eaf formatting for linter 2023-06-10 12:00:52 -04:00
Wing Lian
5878bb1f3a add option to readme 2023-06-10 11:57:41 -04:00
Wing Lian
a03a7d7d8b add support to extend context with xpos rope 2023-06-10 10:29:46 -04:00
Glavin Wiechert
fec6bcc3e6 Add streaming inference & fix stopping at EOS 2023-06-10 08:14:47 +00:00
Wing Lian
931e606459 Merge pull request #179 from OpenAccess-AI-Collective/fix-max_seq_len
fix for max sequence len across different model types
2023-06-09 20:52:03 -04:00
Wing Lian
7f09106437 fix for max sequence len across different model types 2023-06-09 20:42:33 -04:00
NanoCode012
6b50200234 Merge pull request #178 from PocketDocLabs/main
Update README.md to reflect current gradient checkpointing support
2023-06-10 08:26:48 +09:00
PocketDocLabs
16f9e28048 Update README.md to reflect current gradient checkpointing support
Previously the readme stated that gradient checkpointing was incompatible with 4-bit lora in the current implementation; however, this is no longer the case. I have replaced the warning with a link to the Hugging Face documentation on gradient checkpointing.
2023-06-09 16:10:58 -07:00
NanoCode012
b9083a7fc1 Merge pull request #176 from NanoCode012/fix/peft-import
Fix backward compat for peft
2023-06-10 07:56:35 +09:00
NanoCode012
aefb2fc681 Fix backward compat for peft 2023-06-10 07:46:36 +09:00
NanoCode012
b5aa8d854c Merge pull request #169 from NanoCode012/feat/landmark
Feat: Add landmark attention
2023-06-10 07:26:06 +09:00
NanoCode012
4d6490bce2 Merge pull request #171 from OpenAccess-AI-Collective/NanoCode012-falcon-lora-matrix
Fix falcon support lora
2023-06-09 17:58:22 +09:00
NanoCode012
b242b69e10 Fix falcon support lora 2023-06-09 17:50:16 +09:00
NanoCode012
320beb20f4 Merge pull request #170 from OpenAccess-AI-Collective/NanoCode012-lambdalabs-fix
Feat: Improve lambda labs instruction
2023-06-09 16:52:27 +09:00
Angainor Development
bd3b537344 Feed cfg.inference 2023-06-09 08:59:05 +02:00
Angainor Development
813cfa4c14 WIP: Rely on cfg.inference 2023-06-09 08:49:32 +02:00
NanoCode012
2e13ceff37 Improve lambda labs instruction 2023-06-09 15:03:08 +09:00
NanoCode012
2a801b001a Fix grad checkpoint and outputs param 2023-06-09 14:28:44 +09:00
NanoCode012
e44c9e0b3e Fix patching via import instead of hijacking 2023-06-09 14:27:24 +09:00
NanoCode012
55b8542de8 Feat: Add landmark attention 2023-06-09 12:54:08 +09:00
Wing Lian
febe902517 Merge pull request #168 from bratao/main
Disable Wandb if no wandb project is specified
2023-06-08 22:05:56 -04:00
Bruno Cabral
f4df266842 Disable Wandb 2023-06-08 21:02:02 -03:00
NanoCode012
281dc3df59 Merge pull request #167 from NanoCode012/fix/redundant-save-eval-steps
Fix: Refactor out unmodified save_steps and eval_steps
2023-06-09 01:39:33 +09:00
NanoCode012
2ef4634d45 Refactor out unmodified save_steps and eval_steps 2023-06-09 01:23:13 +09:00
NanoCode012
7eae90333e Merge pull request #166 from NanoCode012/fix/seed
Fix: Set to use cfg.seed or 42 for seed
2023-06-09 01:15:08 +09:00
NanoCode012
c8242de725 Merge pull request #132 from utensil/falcon-7b-qlora
Axolotl supports falcon + qlora
2023-06-09 01:14:03 +09:00
NanoCode012
2cfe9e9b16 Set to use cfg.seed or 42 for backward compat 2023-06-09 01:02:36 +09:00
Utensil
79a8f52181 Trim trailing whitespace 2023-06-08 23:48:57 +08:00
NanoCode012
afaa0d2c01 Merge pull request #164 from NanoCode012/fix/falcon-fsdp-validate
Fix: Validate falcon with fsdp
2023-06-09 00:44:12 +09:00
NanoCode012
bfd27ba55e Fix failing test 2023-06-09 00:35:03 +09:00
NanoCode012
babf0fdb71 Validate falcon with fsdp 2023-06-09 00:29:04 +09:00
Utensil
a52f4816b0 Default wandb_project to empty as suggested
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-08 23:04:19 +08:00
NanoCode012
81911d112c Merge pull request #163 from NanoCode012/feat/matmul-tf32
Feat: Set matmul tf32=True when tf32 passed
2023-06-09 00:01:31 +09:00
NanoCode012
52765ac588 Set matmul tf32 2023-06-08 23:41:12 +09:00
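Enabling TF32 matmuls is a single PyTorch flag; roughly what "set matmul tf32" implies when tf32 is passed in the config:

```python
import torch

# Allow TF32 tensor-core math for float32 matmuls (and, typically, cuDNN convolutions).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```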
NanoCode012
73e9ea4069 Merge pull request #143 from NanoCode012/fix/deprecate-prepare-8bit-training
Fix future deprecate prepare_model_for_int8_training
2023-06-08 23:07:53 +09:00
NanoCode012
f8d379883d Merge pull request #162 from NanoCode012/fix/custom-prompt-readme
Fix: Move custom prompts out of hidden
2023-06-08 23:05:17 +09:00
NanoCode012
04a1b77307 Merge pull request #161 from NanoCode012/fix/peft-setup
Fix: Update peft and gptq instruction
2023-06-08 23:01:53 +09:00
NanoCode012
2097a09d2d Move custom prompts out of hidden 2023-06-08 22:53:56 +09:00
NanoCode012
cfff94b123 Add peft install for quickstart 2023-06-08 22:50:20 +09:00
NanoCode012
2b222de5b6 Update peft and gptq instruction 2023-06-08 22:48:26 +09:00
NanoCode012
df9528f865 Fix future deprecate prepare_model_for_int8_training 2023-06-08 21:42:10 +09:00
Angainor Development
193c73bce0 Fix training over existing lora
When training with LoRA and starting from existing LoRA weights, the current code produces a model with 0 trainable params, so training can't work.
Adding the "is_trainable" param allows the loaded peft to be trained and fixes the bug.
2023-06-08 09:18:58 +02:00
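PEFT loads saved adapters frozen by default, which is why resuming from existing LoRA weights showed zero trainable parameters; `is_trainable=True` is the switch described above. A hedged usage sketch (model id and adapter path are hypothetical):

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # hypothetical base

# Without is_trainable=True the loaded adapter is frozen and reports zero
# trainable parameters, so further training silently does nothing.
model = PeftModel.from_pretrained(base_model, "path/to/existing-lora", is_trainable=True)
```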
Wing Lian
6abfd87d44 Merge pull request #158 from OpenAccess-AI-Collective/prompter-fixes
fix camel ai, add guanaco/oasst mapping for sharegpt
2023-06-07 11:02:30 -04:00
Wing Lian
59bb2197ed fix camel ai, add guanaco/oasst mapping for sharegpt 2023-06-07 09:51:29 -04:00
Wing Lian
9a02e7e1ff Merge pull request #155 from OpenAccess-AI-Collective/misc-fixes
new prompters, misc fixes for output dir missing using fsdp, and changing max seq len
2023-06-06 16:52:39 -04:00
Wing Lian
5b33e295bd update docs 2023-06-05 22:48:16 -04:00
Wing Lian
4ac9e251b7 new prompters, misc fixes for output dir missing using fsdp, and changing max seq len 2023-06-05 22:41:00 -04:00
Utensil
c9c050316f Default micro_batch_size to 1 for a safer start 2023-06-03 17:26:33 +08:00
Utensil
ca11ae9689 Add comments/alternatives for falcon-qlora configs 2023-06-03 15:04:02 +08:00
Wing Lian
328c3bce96 Merge pull request #149 from OpenAccess-AI-Collective/docker-clone-axolotl
clone in docker
2023-06-02 15:15:30 -04:00
Wing Lian
5cd2126439 shallow clone 2023-06-02 14:54:28 -04:00
Wing Lian
12620f3089 clone in docker 2023-06-02 14:52:50 -04:00
Wing Lian
4ab0c8b201 Merge pull request #148 from OpenAccess-AI-Collective/fix-device-load 2023-06-02 14:37:17 -04:00
Wing Lian
74ebbf4371 fix device map 2023-06-02 14:29:08 -04:00
Wing Lian
76a70fd739 Merge pull request #147 from OpenAccess-AI-Collective/winglian-rocker-images
Update README.md for correct image tags
2023-06-02 14:10:40 -04:00
Wing Lian
618816d4df Update README.md for correct image tags 2023-06-02 14:10:23 -04:00
Wing Lian
91992cb8f5 Merge pull request #146 from FarisHijazi/main
added docker-compose file
2023-06-02 13:58:23 -04:00
FarisHijazi
84169d15b3 added docker-compose file 2023-06-02 18:17:43 +03:00
Wing Lian
ecfe8d0a1a Merge pull request #142 from NanoCode012/feat/custom-prompt-readme
Feat: Add custom prompt readme and add missing prompt strategies to Readme
2023-06-02 07:21:04 -04:00
Wing Lian
eee44a3b47 Merge pull request #141 from NanoCode012/feat/lambdalabs-readme
Feat: Add lambdalabs instruction
2023-06-02 07:20:12 -04:00
NanoCode012
078a43eef8 Remove redundant instruction 2023-06-02 12:30:11 +09:00
NanoCode012
33e1890086 Add pygmalion 2023-06-02 12:27:51 +09:00
NanoCode012
1c38253692 Add other prompt_strategies 2023-06-02 12:24:44 +09:00
NanoCode012
496b83f778 Add short instruction for custom prompts 2023-06-02 12:16:20 +09:00
NanoCode012
ff68a95781 Add lambdalabs instruction 2023-06-02 12:09:40 +09:00
Utensil
fb3d40f197 falcon + qlora + xformer mbs 40 gas 2 on A6000 2023-06-01 18:29:20 +08:00
NanoCode012
288fd62431 Merge pull request #135 from NanoCode012/fix/grad-accu-readme
Fix: Update doc for grad_accu and add validation tests for batch size
2023-06-01 06:33:05 +09:00
NanoCode012
3c71c8debe Update doc for grad_accu and add validation tests for batch size 2023-06-01 06:13:47 +09:00
Wing Lian
a6f5e5eaec Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix
fix batch size calculation
2023-05-31 14:24:48 -04:00
Wing Lian
5a631b305b fix batch size calculation 2023-05-31 14:11:32 -04:00
Wing Lian
f94dd626f0 Merge pull request #130 from OpenAccess-AI-Collective/gas
swap batch size for gradient accumulation steps to decouple from num gpu
2023-05-31 13:03:51 -04:00
Wing Lian
5079753b7a Merge pull request #131 from OpenAccess-AI-Collective/fix-packing-mask
fix packing so that concatenated sequences reset the attention
2023-05-31 13:03:37 -04:00
Wing Lian
0136f510f2 don't worry about duplicate code here 2023-05-31 12:05:43 -04:00
Utensil
72bf8aafb6 Create config-7b-qlora.yml 2023-06-01 00:00:37 +08:00
Utensil
8afb0fbaba Axolotl supports falcon + qlora 2023-05-31 23:58:40 +08:00
Wing Lian
9b8585dc70 fix packing so that concatenated sequences reset the attention 2023-05-31 11:38:52 -04:00
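For context, an illustrative sketch (not the project's implementation) of what resetting attention between packed sequences means: each packed sample gets its own causal block, and its position ids restart at zero.
```python
# Illustrative only: build position_ids and a block-diagonal causal mask for
# sequences packed back-to-back into a single row, so tokens cannot attend
# across the boundary between concatenated samples.
import torch

def packed_positions_and_mask(seq_lens):
    position_ids = torch.cat([torch.arange(n) for n in seq_lens])  # restarts at 0 per sample
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in seq_lens:
        mask[start:start + n, start:start + n] = torch.tril(torch.ones(n, n, dtype=torch.bool))
        start += n
    return position_ids, mask

pos, mask = packed_positions_and_mask([3, 2])
print(pos)          # tensor([0, 1, 2, 0, 1])
print(mask.int())   # two causal blocks, no cross-sample attention
```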
Wing Lian
8eb5811d4e Merge pull request #129 from OpenAccess-AI-Collective/builder-badge
add badge info to readme
2023-05-31 10:37:59 -04:00
Wing Lian
e0011fdf55 Fix base builder, missing tags 2023-05-31 09:52:03 -04:00
Wing Lian
6e9e98720e Merge pull request #127 from OpenAccess-AI-Collective/py310-docker-runpod
add py310 support from base image
2023-05-31 09:39:42 -04:00
Wing Lian
c2a0792680 swap batch size for gradient accumulation steps to decouple from num gpu 2023-05-31 09:38:12 -04:00
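For context, the arithmetic behind decoupling batch size from GPU count (illustrative numbers, not project code): the effective batch is defined per GPU and then scales with the world size.
```python
# Illustrative arithmetic only
micro_batch_size = 2              # samples per GPU per forward/backward pass
gradient_accumulation_steps = 4   # accumulated passes before each optimizer step
num_gpus = 2                      # world size

effective_batch_size = micro_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 16
```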
Wing Lian
b267d24a2b add badge info to readme 2023-05-31 09:28:44 -04:00
Wing Lian
5c3f5db38b Add files via upload 2023-05-31 09:22:54 -04:00
Wing Lian
e3d03745ba add py310 support from base image 2023-05-31 09:07:28 -04:00
NanoCode012
fac46002d4 Merge pull request #119 from NanoCode012/feat/update-inference
Feat(inference): Swap to GenerationConfig
2023-05-31 14:09:18 +09:00
NanoCode012
33d40179ba Increase max_new_tokens
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-05-31 14:04:49 +09:00
Wing Lian
dcb03d6da4 Merge pull request #114 from OpenAccess-AI-Collective/accelerate-dep
Add accelerate dep
2023-05-31 00:47:17 -04:00
NanoCode012
0e4be625ae Merge pull request #118 from NanoCode012/feat/torch-readme
Fix(readme): Fix torch missing from readme
2023-05-31 13:29:41 +09:00
NanoCode012
bdc4bd7d4e Update README.md 2023-05-31 13:24:28 +09:00
Wing Lian
2d0ba3b818 Merge pull request #124 from OpenAccess-AI-Collective/xformers-fix
copy xformers attn from ooba since we removed dep on alpaca_lora_4bit
2023-05-31 00:11:40 -04:00
Wing Lian
c7021e191f Merge pull request #120 from OpenAccess-AI-Collective/model-from-path
split up llama model loading so config can be loaded from base config and models can be loaded from a path
2023-05-31 00:08:38 -04:00
Wing Lian
c56818b119 don't worry about dupes 2023-05-31 00:06:47 -04:00
Wing Lian
2675fb756e update readme for SDP 2023-05-31 00:04:54 -04:00
Wing Lian
1076bcbbca Update src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-05-31 00:00:19 -04:00
Wing Lian
2daa6835f0 Update src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-05-30 23:59:05 -04:00
Wing Lian
e3c494ca7b remove unused import and update readme 2023-05-30 23:55:45 -04:00
Wing Lian
ad0ea6aaab black formatting
ignore copied file
fix linting
2023-05-30 23:50:29 -04:00
Wing Lian
876edd83d0 Merge pull request #123 from OpenAccess-AI-Collective/bas-batch
add support for gradient accumulation steps
2023-05-30 23:45:29 -04:00
Wing Lian
6cb2310592 copy xformers attn from ooba since we removed dep on alpaca_lora_4bit 2023-05-30 23:34:36 -04:00
Wing Lian
6fa40bf8ad black formatting 2023-05-30 23:33:37 -04:00
Wing Lian
3aad5f3b3e add support for gradient accumulation steps 2023-05-30 23:24:37 -04:00
Wing Lian
39a208c2bc fix up tokenizer config, isort fix 2023-05-30 23:00:02 -04:00
Wing Lian
2520ecd6df split up llama model loading so config can be loaded from base config and models can be loaded from a path 2023-05-30 22:32:44 -04:00
Wing Lian
c5b0af1a7e define python version (3.10) explicitly as string in yaml 2023-05-30 22:23:35 -04:00
NanoCode012
988aeb9c34 Feat: Swap to GenerationConfig 2023-05-31 10:48:19 +09:00
NanoCode012
cf61f14bff FIx(readme): Fix torch missing from readme 2023-05-31 10:28:49 +09:00
Wing Lian
0abcd71a85 Merge pull request #115 from OpenAccess-AI-Collective/docker-version-fixes
docker fixes: py310, fix cuda arg in deepspeed
2023-05-30 18:11:26 -04:00
Wing Lian
c43c5c84ff py310, fix cuda arg in deepspeed 2023-05-30 18:02:34 -04:00
Wing Lian
36ec6e1a0e Add accelerate dep 2023-05-30 16:36:13 -04:00
123 changed files with 9144 additions and 1644 deletions

129
.github/CODE_OF_CONDUCT.md vendored Normal file
View File

@@ -0,0 +1,129 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement on Discord
at https://discord.gg/QYF8QrtEUm
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

76
.github/CONTRIBUTING.md vendored Normal file
View File

@@ -0,0 +1,76 @@
# Contributing to axolotl
First of all, thank you for your interest in contributing to axolotl! We appreciate the time and effort you're willing to invest in making our project better. This document provides guidelines and information to make the contribution process as smooth as possible.
## Table of Contents
- [Code of Conduct](#code-of-conduct)
- [Getting Started](#getting-started)
- [How to Contribute](#how-to-contribute)
- [Reporting Bugs](#reporting-bugs)
- [Suggesting Enhancements](#suggesting-enhancements)
- [Submitting Pull Requests](#submitting-pull-requests)
- [Style Guidelines](#style-guidelines)
- [Code Style](#code-style)
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)
## Code of Conduct
All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.
## Getting Started
Bugs? Please check for an open issue first, otherwise create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **greatly welcome**!
1. Fork the repository and clone it to your local machine.
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
3. Explore the codebase, run tests, and verify that everything works as expected.
Please run the commands below to set up the environment:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
```
## How to Contribute
### Reporting Bugs
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
### Suggesting Enhancements
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
### Submitting Pull Requests
1. Create a new branch for your feature or bugfix. Use a descriptive name like `feature/your-feature-name` or `fix/your-bugfix-name`.
2. Make your changes, following the [Style Guidelines](#style-guidelines) below.
3. Test your changes and ensure that they don't introduce new issues or break existing functionality.
4. Commit your changes, following the [commit message guidelines](#commit-messages).
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
## Style Guidelines
### Code Style
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
### Commit Messages
Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function".
## Additional Resources
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [{codestyle}]({URLofCodestyle})
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

13
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,13 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

105
.github/ISSUE_TEMPLATE/bug-report.yaml vendored Normal file
View File

@@ -0,0 +1,105 @@
name: Bug Report
description: File a bug report
labels: ["bug", "needs triage"]
body:
- type: markdown
attributes:
value: |
## Before you start
Please **make sure you are on the latest version.**
If you encountered the issue after you installed, updated, or reloaded, **please try restarting before reporting the bug**.
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "Please check that this issue hasn't been reported before."
description: "The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports."
required: true
- type: textarea
id: expected
attributes:
label: Expected Behavior
description: Tell us what **should** happen.
validations:
required: true
- type: textarea
id: what-happened
attributes:
label: Current behaviour
description: |
Tell us what happens instead of the expected behavior.
Provide stacktrace and/or screenshots.
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: |
Which exact steps can a developer take to reproduce the issue?
The more detail you provide, the easier it will be to narrow down and fix the bug.
Please paste in tasks and/or queries **as text, not screenshots**.
placeholder: |
Example of the level of detail needed to reproduce any bugs efficiently and reliably.
1. Go to the '...' page.
2. Click on the '...' button.
3. Scroll down to '...'.
4. Observe the error.
validations:
required: true
- type: textarea
id: possible-solution
attributes:
label: Possible solution
description: |
Not obligatory, but please suggest a fix or reason for the bug, if you have an idea.
- type: checkboxes
id: operating-systems
attributes:
label: Which Operating Systems are you using?
description: You may select more than one.
options:
- label: Linux
- label: macOS
- label: Windows
- type: input
id: Python-version
attributes:
label: Python Version
description: Which Python version are you using?
placeholder: 3.10 / please change accordingly
validations:
required: true
- type: input
id: axolotl-branch-commit
attributes:
label: axolotl branch-commit
description: On which branch/commit are you?
placeholder: main/4d6490b
validations:
required: true
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this bug has not been reported yet.'
required: true
- label: 'I am using the latest version of axolotl.'
required: true
- label: 'I have provided enough information for the maintainers to reproduce and diagnose the issue.'
required: true

7
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,7 @@
blank_issues_enabled: false
contact_links:
- name: Ask a question
url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
about: Ask questions and discuss with other community members
- name: Discuss the Project in Discord
url: https://discord.gg/HhrNrHJPRb

46
.github/ISSUE_TEMPLATE/docs.yml vendored Normal file
View File

@@ -0,0 +1,46 @@
name: Documentation Improvement / Clarity
description: Make a suggestion to improve the project documentation.
labels: ['needs triage', 'docs']
body:
- type: markdown
attributes:
value: '## :book: Documentation :book:'
- type: markdown
attributes:
value: |
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
- type: textarea
attributes:
label: What piece of documentation is affected?
description: Please link to the article you'd like to see updated.
validations:
required: true
- type: textarea
attributes:
label: What part(s) of the article would you like to see updated?
description: |
- Give as much detail as you can to help us understand the change you want to see.
- Why should the docs be changed? What use cases does it support?
- What is the expected outcome?
validations:
required: true
- type: textarea
attributes:
label: Additional Information
description: Add any other context or screenshots about the feature request here.
validations:
required: false
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true

View File

@@ -0,0 +1,63 @@
name: Feature Request / Enhancement
description: Suggest a new feature or feature enhancement for the project
labels: ["enhancement", "needs triage"]
body:
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "⚠️ Please check that this feature request hasn't been suggested before."
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
required: true
- label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests."
required: true
- type: textarea
id: feature-description
validations:
required: true
attributes:
label: "🔖 Feature description"
description: "A clear and concise description of what the feature request is."
placeholder: "You should add ..."
- type: textarea
id: solution
validations:
required: true
attributes:
label: "✔️ Solution"
description: "A clear and concise description of what you want to happen, and why."
placeholder: "In my use-case, ..."
- type: textarea
id: alternatives
validations:
required: false
attributes:
label: "❓ Alternatives"
description: "A clear and concise description of any alternative solutions or features you've considered."
placeholder: "I have considered ..."
- type: textarea
id: additional-context
validations:
required: false
attributes:
label: "📝 Additional Context"
description: "Add any other context or screenshots about the feature request here."
placeholder: "..."
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true

View File

@@ -0,0 +1,22 @@
<!--- Provide a general summary of your changes in the Title above -->
# Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
## How has this been tested?
<!--- Please describe in detail how you tested your changes. -->
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## Screenshots (if appropriate)
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->

9
.github/SECURITY.md vendored Normal file
View File

@@ -0,0 +1,9 @@
# Security Policy
## Supported Versions
Due to the nature of the fast development that is happening in this project, only the latest released version can be supported.
## Reporting a Vulnerability
If you find a vulnerability, please contact us on [Discord](https://discord.gg/xcu3ECkH9a) rather than creating a GitHub issue to allow us some time to fix it before it is a known vulnerability to others.

10
.github/SUPPORT.md vendored Normal file
View File

@@ -0,0 +1,10 @@
# Support
If you need help with this project or have questions, please:
1. Check the documentation.
2. Search the existing issues and pull requests.
3. Create a new issue if your question is not answered or your problem is not solved.
4. Have a look in the [Discord server](https://discord.gg/HhrNrHJPRb).
Please note that this project is maintained by volunteers who have limited availability. We'll do our best to address your questions and concerns in a timely manner.

View File

@@ -12,19 +12,19 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: self-hosted
strategy:
fail-fast: false
matrix:
include:
- cuda: "118"
cuda_version: 11.8.0
axolotl_extras:
- cuda: "117"
cuda_version: 11.7.0
pytorch: 1.13.1
axolotl_extras:
python_version: "3.9"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
pytorch: 2.0.0
axolotl_extras: gptq
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -46,12 +46,11 @@ jobs:
context: .
file: ./docker/Dockerfile-base
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}

View File

@@ -4,26 +4,24 @@ on:
push:
branches:
- "main"
- "dev"
jobs:
build-axolotl:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
pytorch: 2.0.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
pytorch: 2.0.0
axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.0
pytorch: 1.13.1
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
runs-on: self-hosted
steps:
@@ -46,13 +44,12 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-axolotl-runpod:
needs: build-axolotl
if: github.repository_owner == 'OpenAccess-AI-Collective'
@@ -60,18 +57,17 @@ jobs:
strategy:
matrix:
include:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
pytorch: 2.0.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
pytorch: 2.0.0
axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.0
pytorch: 1.13.1
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
is_latest: true
runs-on: self-hosted
steps:
- name: Checkout
@@ -93,10 +89,11 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-runpod
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@@ -7,6 +7,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.9", "3.10"]
timeout-minutes: 10

View File

@@ -5,6 +5,9 @@ exclude = venv
[mypy-alpaca_lora_4bit.*]
ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True
[mypy-flash_attn.*]
ignore_missing_imports = True
@@ -31,3 +34,6 @@ ignore_missing_imports = True
[mypy-addict]
ignore_missing_imports = True
[mypy-xformers.*]
ignore_missing_imports = True

View File

@@ -1,5 +1,5 @@
default_language_version:
python: python3.9
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks

View File

@@ -2,3 +2,6 @@
- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.

202
LICENSE Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

492
README.md
View File

@@ -1,10 +1,40 @@
# Axolotl
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
<table>
<tr>
<td>
## Table of Contents
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Installation](#installation)
- [Docker Installation](#environment)
- [Conda/Pip venv Installation](#condapip-venv)
- [LambdaLabs Installation](#lambdalabs)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-)
- [Need Help?](#need-help-)
- [Badge](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing](#contributing-)
</td>
<td>
<div align="center">
<img src="image/axolotl.png" alt="axolotl" width="160">
<div>
<p>
<b>One repo to finetune them all! </b>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p>
<p>
Go ahead and axolotl questions!!
@@ -14,33 +44,41 @@
</div>
</div>
</td>
</tr>
</table>
## Axolotl supports
| | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
|---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | | ❌ | ❌ | ❌ | ❌ | ❓ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|---------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
## Quickstart ⚡
**Requirements**: Python 3.9.
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.9 and Pytorch >=2.0.
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install -e .
accelerate config
pip3 install -e .[flash-attn]
pip3 install -U git+https://github.com/huggingface/peft.git
# finetune lora
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
# inference
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
--inference --lora_model_dir="./lora-out"
```
@@ -50,28 +88,91 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
- Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
```
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
Or run on the current files for development:
```sh
docker compose up -d
```
- `winglian/axolotl:dev`: dev branch
- `winglian/axolotl-runpod:main`: for runpod
- Conda/Pip venv
1. Install python **3.9**
1. Install python >=**3.9**
2. Install python dependencies with ONE of the following:
- `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support)
- `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq)
- `pip3 install -e .[gptq_triton]`
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install python dependencies with ONE of the following:
- Recommended, supports QLoRA, NO gptq/int4 support
```bash
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
```
- gptq/int4 support, NO QLoRA
```bash
pip3 install -e .[gptq]
```
- same as above but not recommended
```bash
pip3 install -e .[gptq_triton]
```
- LambdaLabs
<details>
<summary>Click to Expand</summary>
1. Install python
```bash
sudo apt update
sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
python -V # should be 3.9
```
2. Install pip
```bash
wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
```
3. Install torch
```bash
pip3 install -U torch --index-url https://download.pytorch.org/whl/cu118
```
4. Axolotl
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install -e . # change depending on needs
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
```
5. Set path
```bash
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
```
</details>
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following formats (JSONL recommended):
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -112,16 +213,99 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"article": "...", "summary": "..."}
```
> Have some new format to propose? Check if it's already defined in [data.py](src/axolotl/utils/data.py) in `dev` branch!
- `alpaca_chat`: basic instruct for alpaca chat
```json
{"instruction": "...", "input": "...", "response": "..."}
```
- `alpaca_chat.load_qa`: question and answer for alpaca chat
```json
{"question": "...", "answer": "..."}
```
- `alpaca_chat.load_concise`: question and answer for alpaca chat, for concise answers
```json
{"instruction": "...", "input": "...", "response": "..."}
```
- `alpaca_chat.load_camel_ai`: question and answer for alpaca chat, for load_camel_ai
```json
{"message_1": "...", "message_2": "..."}
```
- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
```json
{"system_prompt": "...", "question": "...", "response": "..."}
```
- `context_qa`: in context question answering from an article
```json
{"article": "...", "question": "...", "answer": "..."}
```
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
```json
{"article": "...", "unanswerable_question": "..."}
```
- `creative_acr.load_answer`: instruction and revision
```json
{"instruction": "...", "revision": "..."}
```
- `creative_acr.load_critique`: critique
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "..."}
```
- `creative_acr.load_revise`: critique and revise
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
```
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```json
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
</details>
Optionally, download some datasets, see [data/README.md](data/README.md)
#### How to add custom prompts
Using yaml. Example:
```yaml
datasets:
- path: repo
type:
system_prompt: ""
no_input_format: |-
User: {instruction}<|end_of_turn|>
Assistant:
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
```
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
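A hypothetical example, assuming you added a file named `my_format.py` under `prompt_strategies` exposing a `load_qa` function (file and function names are illustrative):
```yaml
datasets:
  - path: my_dataset.jsonl   # illustrative local dataset
    type: my_format.load_qa  # <prompt_strategies_file>.load_<load_fn>
```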
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
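A minimal sketch of one pretokenized JSONL row (token ids are made up; `-100` is the conventional label value for tokens excluded from the loss):
```json
{"input_ids": [1, 319, 4996, 1243], "attention_mask": [1, 1, 1, 1], "labels": [-100, -100, 4996, 1243]}
```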
### Config
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
- model
```yaml
@@ -131,10 +315,33 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
- dataset
```yaml
sequence_len: 2048 # max token length for prompt
# huggingface repo
datasets:
- path: vicgalle/alpaca-gpt4 # local or huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
sequence_len: 2048 # max token length / prompt
# huggingface repo with specific configuration/subset
datasets:
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# local
datasets:
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
```
- loading
@@ -144,6 +351,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
bf16: true # require >=ampere
fp16: true
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
Note: Repo does not do 4-bit quantization.
@@ -171,12 +380,24 @@ base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# you can specify to choose a specific model revision from huggingface hub
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
model_type: AutoModelForCausalLM
# Corresponding tokenizer for the model AutoTokenizer is a good choice
tokenizer_type: AutoTokenizer
# Trust remote code for untrusted source
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -195,24 +416,53 @@ fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true
# a list of one or more datasets to finetune the model with
datasets:
# this can be either a hf dataset, or relative path
# hf dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format OR format:prompt_style (chat/instruct)
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
data_files: # path to source data files
shards: # number of shards to split data into
name: # name of dataset configuration to load
# custom user prompt
- path: repo
type:
# the below are defaults. only set what's needed.
system_prompt: ""
field_system: system
field_instruction: instruction
field_output: input
# customizable to be single line or multi-line
system_format: "{system}"
# 'format' can include {input}
format: |-
User: {instruction} {input}
Assistant:
# 'no_input_format' cannot include {input}
no_input_format: "{instruction} "
# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04
# Num shards for whole dataset
dataset_shard_num:
@@ -222,9 +472,19 @@ dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommended to set to 'true'
sample_packing:
# you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these settings.
sample_packing_eff_est:
total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
@@ -249,31 +509,49 @@ lora_modules_to_save:
lora_out_dir:
lora_fan_in_fan_out: false
# ReLoRA configuration
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # number of steps per ReLoRA restart
relora_warmup_steps: # number of per-restart warmup steps
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
wandb_mode:
wandb_project:
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # your wandb project name
wandb_entity: # a wandb Team name if using a Team
wandb_watch:
wandb_run_id:
wandb_log_model: # 'checkpoint'
wandb_run_id: # set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# where to save the finished model to
output_dir: ./completed-model
# training hyperparameters
batch_size: 8
gradient_accumulation_steps: 1
micro_batch_size: 2
eval_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
save_strategy: # set to `no` to skip checkpoint saves
save_steps: # leave empty to save at each epoch
eval_steps: # leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time
max_steps:
# save model as safetensors (require safetensors package)
save_safetensors:
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
# group similarly sized data to minimize padding
# may be slower to start, as it must download and sort the entire dataset
# note that training loss may have an oscillating pattern with this enabled
group_by_length: false
# does not work with current implementation of 4-bit LoRA
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row
@@ -295,11 +573,31 @@ log_sweep_max_lr:
optimizer:
# specify weight decay
weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_epsilon:
# Gradient clipping max norm
max_grad_norm:
# whether to bettertransformers
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
flash_attention: # require a100 for llama
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Landmark attention (only llama)
landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only
xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# resume from a specific checkpoint dir
resume_from_checkpoint:
@@ -322,14 +620,14 @@ tokens:
fsdp:
fsdp_config:
# Deepspeed
# Deepspeed config path
deepspeed:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
# Set padding for data collator to 'longest'
collator_pad_to_longest:
# Set to an HF dataset with type: 'completion' to stream data instead of pre-tokenizing
pretraining_dataset:
# Debug mode
debug:
@@ -343,22 +641,48 @@ strict:
</details>
### Accelerate
Configure accelerate
```bash
accelerate config
# Edit manually
# nano ~/.cache/huggingface/accelerate/default_config.yaml
```
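For reference, a minimal single-GPU `default_config.yaml` looks roughly like the sketch below (mirroring the default config removed elsewhere in this changeset); adjust `mixed_precision`, `num_machines`, and `num_processes` for your hardware:
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: 'NO'
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
use_cpu: false
```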
### Train
Run
```bash
accelerate launch scripts/finetune.py configs/your_config.yml
accelerate launch scripts/finetune.py your_config.yml
```
#### Multi-GPU
You can optionally pre-tokenize the dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
```
##### Config
- llama FSDP
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
- llama Deepspeed
```yaml
deepspeed: deepspeed/zero3.json
```
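This changeset also adds a ZeRO-2 config at `deepspeed/zero2.json`, so the same pattern works for stage 2 (assuming the JSON files stay at these relative paths):
```yaml
deepspeed: deepspeed/zero2.json
```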
##### Weights & Biases Logging
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
```
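A filled-in sketch (the project, entity, and run names below are hypothetical placeholders):
```yaml
wandb_project: my-axolotl-runs     # hypothetical project name
wandb_entity: my-team              # hypothetical Team name
wandb_watch:
wandb_run_id: llama-lora-test-1    # hypothetical run name
wandb_log_model: end               # only log the model at the end of training
```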
### Inference
@@ -367,11 +691,16 @@ Pass the appropriate flag to the train command:
- Pretrained LORA:
```bash
--inference --lora_model_dir ./completed-model
--inference --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
--inference --base_model ./completed-model
--inference --base_model="./completed-model"
```
- Full weights finetune w/ a prompt from a text file:
```bash
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
```
### Merge LORA to base
@@ -382,15 +711,28 @@ Add below flag to train command above
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
```
If you run out of CUDA memory, you can try to merge in system RAM with
```bash
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
```
## Common Errors 🧰
> Cuda out of memory
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
Please reduce any of the settings below (a sketch of reduced values follows the list):
- `micro_batch_size`
- `eval_batch_size`
- `gradient_accumulation_steps`
- `sequence_len`
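For example, a conservative starting point might look like the sketch below; the values are illustrative, not tuned recommendations:
```yaml
micro_batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 1
sequence_len: 1024
```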
> `failed (exitcode: -9)`
This usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM, which should be simpler than a GPU upgrade.
> RuntimeError: expected scalar type Float but found Half
Try setting `fp16: true`
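For example, assuming your GPU supports half precision (the example configs in this repo enable exactly one of `bf16`/`fp16`):
```yaml
fp16: true
bf16: false
```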
@@ -399,13 +741,41 @@ Try set `fp16: true`
Try to turn off xformers.
## Need help? 🙋‍♂️
> accelerate config missing
It's safe to ignore it.
## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
```
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
## Community Showcase
Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
PocketDoc Labs
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
## Contributing 🤝
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues/bug), or create a new Issue if yours isn't listed.
PRs are **greatly welcome**!


@@ -1,15 +0,0 @@
compute_environment: LOCAL_MACHINE
distributed_type: 'NO'
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


@@ -1,40 +0,0 @@
base_model: cerebras/Cerebras-GPT-1.3B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
type: alpaca
- path: data/vicuna_cleaned.jsonl
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- c_attn
lora_fan_in_fan_out: false
wandb_project: pythia-1.4b-lora
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-alpaca
batch_size: 32
micro_batch_size: 4
num_epochs: 5
learning_rate: 0.0003
train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:


@@ -1,41 +0,0 @@
base_model: facebook/galactica-1.3b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
adapter:
lora_model_dir:
sequence_len: 1024
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-llama-alpaca
batch_size: 32
micro_batch_size: 16
num_epochs: 3
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: false
tf32: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
tokens:
pad_token: "[PAD]"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -1,39 +0,0 @@
base_model: huggyllama/llama-13b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
type: sharegpt
dataset_prepared_path: last_run_prepared
val_set_size: 0.002
adapter:
lora_model_dir:
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./llama-13b-sharegpt
batch_size: 64
micro_batch_size: 2
warmup_steps: 1000
save_steps:
eval_steps:
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience: 5
resume_from_checkpoint:
local_rank:


@@ -1,44 +0,0 @@
base_model: huggyllama/llama-65b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
type: alpaca
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-65b-lora
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-llama-alpaca
batch_size: 128
micro_batch_size: 16
warmup_steps: 1000
save_steps:
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:


@@ -1,45 +0,0 @@
base_model: decapoda-research/llama-7b-hf-int4
base_model_config: decapoda-research/llama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: tatsu-lab/alpaca # original alpaca dataset
type: alpaca
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-test
batch_size: 8
micro_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
load_4bit: true
xformers_attention: true
flash_attention:


@@ -1,41 +0,0 @@
base_model: huggyllama/llama-7b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
type: alpaca
- path: data/vicuna_cleaned.jsonl
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-llama-alpaca
batch_size: 128
micro_batch_size: 16
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:


@@ -1,45 +0,0 @@
base_model: decapoda-research/llama-7b-hf-int4
base_model_config: decapoda-research/llama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: tatsu-lab/alpaca # original alpaca dataset
type: alpaca
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 1024
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-test
batch_size: 4
micro_batch_size: 1
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
gptq: true
xformers_attention: true
flash_attention:


@@ -1,86 +0,0 @@
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# this can also be a relative path to a model on disk
base_model: decapoda-research/llama-7b-hf-int4
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model
base_model_config: decapoda-research/llama-7b-hf
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
model_type: AutoModelForCausalLM
# Corresponding tokenizer for the model; AutoTokenizer is a good choice
tokenizer_type: AutoTokenizer
# whether you are training a 4-bit quantized model
load_4bit: true
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
# a list of one or more datasets to finetune the model with
datasets:
# this can be either a hf dataset, or relative path
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca
# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
val_set_size: 0.04
# if you want to use lora, leave blank to train all parameters in original model
adapter: lora
# if you already have a lora model trained that you want to load, put that here
lora_model_dir:
# the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
max_packed_sequence_len: 1024
# lora hyperparameters
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
# wandb configuration if you're using it
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
# where to save the finished model to
output_dir: ./completed-model
# training hyperparameters
batch_size: 8
micro_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false
# Use CUDA bf16
bf16: true
# Use CUDA tf32
tf32: true
# does not work with current implementation of 4-bit LoRA
gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
# specify a scheduler to use with the optimizer. only one_cycle is supported currently
lr_scheduler:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
flash_attention:
# resume from a specific checkpoint dir
resume_from_checkpoint:
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
# be careful with this being turned on between different models
auto_resume_from_checkpoints: false
# don't mess with this, it's here for accelerate and torchrun
local_rank:


@@ -1,56 +0,0 @@
base_model: stabilityai/stablelm-base-alpha-3b
base_model_config: stabilityai/stablelm-base-alpha-3b
load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter:
lora_model_dir:
sequence_len: 4096
max_packed_sequence_len: 4096
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: stable-alpaca-3b
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./stable-alpaca-3b
batch_size: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100
eval_steps: 50
save_steps: 200
debug:
deepspeed:
weight_decay: 0.01
fsdp:
fsdp_config:
#tokens:
# pad_token: "[PAD]"
# bos_token: "<s>"
# eos_token: "</s>"
# unk_token: "<unk>"


@@ -1,45 +0,0 @@
base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_4bit: true
gptq_groupsize: 128
gptq_model_v1: false
datasets:
# https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
- path: data/alpaca_reflect_pruned.jsonl
type: reflection
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-reflect
batch_size: 8
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
early_stopping_patience: 3
resume_from_checkpoint:
local_rank:
flash_attention: true


@@ -1,24 +0,0 @@
## Download some datasets
```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
```
## Convert the JSON data files to JSONL.
```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
---
Using JSONL makes it easier to subset the data if you want a smaller training set, e.g. to get 2000 random examples.
```shell
shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
```

data/raw/.gitignore (vendored, 1 line)

@@ -1 +0,0 @@
**

deepspeed/zero2.json (new file, 46 lines)

@@ -0,0 +1,46 @@
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}


@@ -37,18 +37,18 @@
"lr": "auto",
"betas": [
0.9,
0.999
0.95
],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"scheduler": {
"type": "OneCycle",
"type": "WarmupLR",
"params": {
"cycle_min_lr": 0.00001,
"cycle_max_lr": 0.00003,
"cycle_first_step_size": 120
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"train_batch_size": "auto",

docker-compose.yaml (new file, 20 lines)

@@ -0,0 +1,20 @@
# version: '3.8'
services:
axolotl:
build:
context: .
dockerfile: ./docker/Dockerfile
volumes:
- .:/workspace/axolotl
- ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables
environment:
- WANDB_API_KEY=${WANDB_API_KEY}
deploy:
resources:
reservations:
devices:
- driver: nvidia
# count: 1
capabilities: [gpu]
command: tail -f /dev/null


@@ -3,25 +3,27 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
RUN apt-get update && \
apt-get install -y vim curl
WORKDIR /workspace
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"transformers @ git+https://github.com/huggingface/transformers.git@main"
RUN mkdir axolotl
COPY . axolotl/
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[$AXOLOTL_EXTRAS]; \
pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \
else \
pip install -e .; \
pip install -e .[flash-attn,gptq]; \
fi
# fix so that git fetch/pull from remote works
RUN cd axolotl && \
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store


@@ -8,7 +8,7 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9"
ARG PYTORCH="2.0.0"
ARG PYTORCH_VERSION="2.0.1"
ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -29,29 +29,12 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS flash-attn-builder
WORKDIR /workspace
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
RUN git clone https://github.com/HazyResearch/flash-attention.git && \
cd flash-attention && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \
cd ../xentropy && \
python3 setup.py bdist_wheel && \
cd ../rotary && \
python3 setup.py bdist_wheel && \
cd ../layer_norm && \
python3 setup.py bdist_wheel
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS deepspeed-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
WORKDIR /workspace
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
@@ -71,11 +54,14 @@ RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
@@ -84,15 +70,10 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
RUN pip3 install wheels/deepspeed-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic
pip3 install -U --no-cache-dir pydantic==1.10.10


@@ -1,6 +1,10 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
RUN apt install --yes --no-install-recommends openssh-server tmux && \


@@ -0,0 +1,61 @@
base_model: cerebras/Cerebras-GPT-1.3B
base_model_config: cerebras/Cerebras-GPT-1.3B
load_in_8bit: false
load_in_4bit: true
strict: false
push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- c_fc
- c_attn
- c_proj
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
batch_size: 4
micro_batch_size: 4
num_epochs: 2
optimizer: paged_adamw_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"


@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 100000
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,22 @@
# Overview
This is an example CodeLlama configuration for the 7b, 13b and 34b variants.
The 7b variant fits on any 24GB VRAM GPU and takes up about 17 GB of VRAM during training with qlora, or 20 GB with lora. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
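The settings that keep the 7b qlora run within that budget are, in sketch form (taken from examples/code-llama/7b/qlora.yml in this changeset):
```yaml
adapter: qlora
load_in_4bit: true
gradient_checkpointing: true
micro_batch_size: 2
gradient_accumulation_steps: 4
flash_attention: true
```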
The 13b variant will fit if you change these settings to the following values:
gradient_accumulation_steps: 2
micro_batch_size: 1
The 34b variant does not fit in 24 GB of VRAM; you will need a GPU with 40+ GB of VRAM that also supports flash attention v2, such as an A6000 or A100.
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml
```
or
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml
```


@@ -23,7 +23,8 @@ lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project: falcon-7b
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:


@@ -0,0 +1,93 @@
# 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
- path: QingyiSi/Alpaca-CoT
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 64b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 3
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 64b
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 5
save_steps: 10
debug:
deepspeed:
weight_decay: 0.000001
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"


@@ -23,7 +23,8 @@ lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project: falcon-7b
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

examples/gptj/qlora.yml (new file, 58 lines)

@@ -0,0 +1,58 @@
base_model: EleutherAI/gpt-j-6b
base_model_config: EleutherAI/gpt-j-6b
load_in_8bit: false
load_in_4bit: true
strict: false
push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 2
optimizer: paged_adamw_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"


@@ -1,8 +0,0 @@
# LLaMa 7B using LoRA
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
```shell
accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
```


@@ -7,30 +7,29 @@ datasets:
- path: openaccess-ai-collective/jeopardy
type: jeopardy
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
sequence_len: 512
max_packed_sequence_len:
lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: jeopardy-bot-7b
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./jeopardy-bot-7b
batch_size: 4
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
@@ -48,11 +47,10 @@ eval_steps: 110
save_steps: 660
debug:
deepspeed:
weight_decay: 0.0001
weight_decay: 0.1
fsdp:
fsdp_config:
tokens:
pad_token: "[PAD]"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,20 @@
# Overview
This is an example llama-2 configuration for 7b and 13b. The yaml file contains the configuration for the 7b variant, but you can just as well use the same settings for 13b.
The 7b variant fits on any 24GB VRAM GPU and takes up about 17 GB of VRAM during training with qlora, or 20 GB with lora. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
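The memory-relevant settings behind those numbers are, in sketch form (taken from examples/llama-2/qlora.yml in this changeset):
```yaml
adapter: qlora
load_in_4bit: true
sequence_len: 4096
sample_packing: true
gradient_checkpointing: true
micro_batch_size: 2
gradient_accumulation_steps: 4
flash_attention: true
```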
The 13b variant will fit if you change these settings to the following values:
gradient_accumulation_steps: 2
micro_batch_size: 1
```shell
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
```
or
```shell
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
```


@@ -0,0 +1,76 @@
base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_bits: 4
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- k_proj
- o_proj
- q_proj
- v_proj
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000017
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
sdp_attention:
flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100
eval_steps:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

examples/llama-2/lora.yml (new file, 67 lines)

@@ -0,0 +1,67 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,73 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./relora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
relora_steps: 150
relora_warmup_steps: 10
relora_cpu_offload: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -20,11 +20,12 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./mpt-alpaca-7b
batch_size: 1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit


@@ -0,0 +1,16 @@
# openllama-3b
Basic full tune
```shell
accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
```
LoRA
```shell
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
```
QLoRA
```shell
accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
```


@@ -1,62 +1,63 @@
base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code:
load_in_8bit: true
gptq: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
datasets:
- path: vicgalle/alpaca-gpt4
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 2048
sequence_len: 256
max_packed_sequence_len:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./llama-7b-lora-int4
batch_size: 1
output_dir: ./openllama-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
fp16: true
float16: true
bf16: false
tf32: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 5
xformers_attention:
logging_steps: 1
xformers_attention: true
flash_attention:
gradient_checkpointing: true
gptq_groupsize: 128
gptq_model_v1: false
warmup_steps: 20
eval_steps: 110
save_steps: 660
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
save_steps:
debug:
deepspeed:
weight_decay: 0.0001
weight_decay: 0.1
fsdp:
fsdp_config:
tokens:
pad_token: "[PAD]"
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -1,5 +1,5 @@
base_model: openlm-research/open_llama_3b_600bt_preview
base_model_config: openlm-research/open_llama_3b_600bt_preview
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
@@ -28,6 +28,7 @@ lora_target_modules:
- o_proj
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -49,7 +50,7 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:


@@ -1,5 +1,5 @@
base_model: openlm-research/open_llama_3b_600bt_preview
base_model_config: openlm-research/open_llama_3b_600bt_preview
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
@@ -22,6 +22,7 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -34,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: true
group_by_length: false
bf16: true
fp16: false
tf32: true


@@ -0,0 +1,9 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multi-GPU A100: doesn't seem to work with multiple GPUs without causing OOM! ⚠️


@@ -1,39 +1,49 @@
base_model: EleutherAI/gpt-neox-20b
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto
datasets:
- path: nomic-ai/gpt4all-j-prompt-generations
- path: vicgalle/alpaca-gpt4
type: alpaca
shards: 4
shards_index: 0
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_dropout: 0.0
lora_target_modules:
- query_key_value
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: gpt4all-neox-20b
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./gpt4all-neox-20b
batch_size: 48
micro_batch_size: 4
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
lr_scheduler: one_cycle
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:


@@ -1,36 +1,30 @@
base_model: EleutherAI/pythia-1.4b-deduped
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
- path: data/vicuna_cleaned.jsonl
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
lora_model_dir:
sequence_len: 2048
lora_r: 8
sequence_len: 512
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- query_key_value
# - xxx
lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: pythia-1.4b-lora
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-alpaca
batch_size: 48
output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 5
num_epochs: 3
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
@@ -39,3 +33,6 @@ tf32: True
early_stopping_patience:
resume_from_checkpoint:
local_rank:
weight_decay: 0.1
eval_steps: 20
logging_steps: 1


@@ -1,6 +0,0 @@
# qlora-openllama-3b
```shell
accelerate launch scripts/finetune.py examples/qlora-openllama-3b/config.yml
```


@@ -1,7 +1,7 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM
tokenizer_type: GPTNeoXTokenizer
tokenizer_type: AutoTokenizer
trust_remote_code:
load_in_8bit: false
datasets:
@@ -21,6 +21,7 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:


@@ -20,6 +20,7 @@ lora_target_modules:
- mlp_down
lora_fan_in_fan_out:
wandb_project: lora-replit
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:


@@ -0,0 +1,91 @@
# An example finetuning Salesforce's XGen-7b model with 8k context using qlora
# on Tim Dettmers' Guanaco dataset.
base_model: Salesforce/xgen-7b-8k-base
base_model_config: Salesforce/xgen-7b-8k-base
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
- path: timdettmers/openassistant-guanaco
data_files:
- openassistant_best_replies_train.jsonl
type: "completion"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 8192
max_packed_sequence_len:
# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 64b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 1
num_epochs: 3
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 64b
learning_rate: 0.00002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
eos_token: "<|endoftext|>"
bos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"

image/axolotl-badge-web.png (new binary file, 11 KiB)

@@ -1,18 +1,31 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.39.0
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
addict
fire
PyYAML==6.0
PyYAML>=6.0
datasets
accelerate>=0.19.0
flash-attn>=2.0.8
sentencepiece
wandb
einops
xformers
optimum
hf_transfer
colorama
numba
numpy>=1.24.4
# qlora things
bert-score==0.3.13
evaluate==0.4.0
rouge-score==0.1.2
scipy
scikit-learn==1.2.2
pynvml
art


@@ -1,49 +0,0 @@
"""Module to convert json file to jsonl"""
import os
import sys
from pathlib import Path
from typing import Optional, Union
import fire
from axolotl.convert import (
FileReader,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
StdoutWriter,
)
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
def main(
file: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
"""
Convert a json file to jsonl
"""
file_reader = FileReader()
writer: Union[StdoutWriter, FileWriter]
if to_stdout or output is None:
writer = StdoutWriter()
else:
writer = FileWriter(output)
json_parser = JsonParser()
jsonl_serializer = JsonlSerializer()
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
converter.convert(file, output)
if __name__ == "__main__":
fire.Fire(main)


@@ -6,50 +6,62 @@ import os
import random
import signal
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import transformers
import yaml
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from art import text2art
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model, load_model_config, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
if torch.backends.mps.is_available():
return "mps"
debug: bool = field(default=False)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device == "cuda":
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(ascii_text, font=font)
if is_main_process():
print(ascii_art)
def get_multi_line_input() -> Optional[str]:
@@ -61,38 +73,72 @@ def get_multi_line_input() -> Optional[str]:
return instruction
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"})
tokenizer.add_special_tokens({"bos_token": "<s>"})
tokenizer.add_special_tokens({"eos_token": "</s>"})
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
if prompter == "None":
prompter = None
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
# gc = GenerationConfig() # TODO swap out and use this
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
do_sample=True,
use_cache=True,
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=100,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -104,6 +150,10 @@ def choose_config(path: Path):
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
@@ -127,10 +177,141 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
def train(
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
**kwargs,
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
if not (
cli_args.shard or cli_args.merge_lora or cli_args.inference
): # don't need to load dataset for these
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
)
if cli_args.prepare_ds_only:
LOG.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
if cli_args.merge_lora and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cli_args.inference:
LOG.debug("Running inference on model")
do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
return
if cli_args.shard:
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
@@ -142,140 +323,39 @@ def train(
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False:
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = (
cfg.gradient_accumulation_steps // cfg.world_size
)
setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg)
# load the tokenizer first
logging.info("loading tokenizer...")
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
normalize_config(cfg)
if check_not_in(
["inference", "shard", "merge_lora"], kwargs
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
setup_wandb_env_vars(cfg)
return cfg
if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
logging.info("loading model and peft_config...")
model, peft_config = load_model(
cfg.base_model,
cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
inference=("inference" in kwargs),
def do_train(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if "merge_lora" in kwargs and cfg.adapter is not None:
logging.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
logging.info("saving merged model")
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if "inference" in kwargs:
logging.info("calling do_inference function")
do_inference(cfg, model, tokenizer)
return
if "shard" in kwargs:
model.save_pretrained(cfg.output_dir)
return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
logging.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
signal.signal(
signal.SIGINT,
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
)
logging.info("Starting trainer...")
if cfg.group_by_length:
logging.info("hang tight... sorting dataset for group_by_length")
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
logging.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
train(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(train)
fire.Fire(do_train)
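
A small illustration of the auto-resume logic above (checkpoint names are hypothetical, not from the diff):

# with auto_resume_from_checkpoints: true and an output_dir containing
#   checkpoint-50, checkpoint-100, checkpoint-150
# sorted(paths, key=lambda path: int(path.split("-")[-1]))[-1] picks "checkpoint-150",
# so training resumes from the highest step count; a plain string sort would instead
# put "checkpoint-50" last and resume from the wrong checkpoint.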

scripts/runpod-entrypoint.sh Normal file → Executable file

@@ -1,10 +1,21 @@
#!/bin/bash
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
# Export specific ENV variables to /etc/rp_environment
echo "Exporting environment variables..."
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc
# Start the SSH service in the background
service ssh start
if [[ $PUBLIC_KEY ]]
then
mkdir -p ~/.ssh
chmod 700 ~/.ssh
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
# Start the SSH service in the background
service ssh start
else
echo "No PUBLIC_KEY ENV variable provided, not starting openSSH daemon"
fi
# Execute the passed arguments (CMD)
exec "$@"


@@ -2,14 +2,27 @@
from setuptools import find_packages, setup
install_requires = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
reqs = [r for r in reqs if r and r[0] != "#"]
for r in reqs:
install_requires.append(r)
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [
r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r
]
for line in lines:
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif "flash-attn" not in line and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
return _install_requires, _dependency_links
install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
@@ -18,15 +31,15 @@ setup(
package_dir={"": "src"},
packages=find_packages(),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
"gptq": [
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
"auto-gptq",
],
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
"flash-attn": [
"flash-attn==2.0.8",
],
"extras": [
"flash-attn",
"deepspeed",
],
},
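
To make parse_requirements() above concrete, this is roughly what it would return for the requirements.txt shown earlier in this diff (approximate, for illustration only):

# install_requires ≈ ["torch==2.0.1", "packaging",
#                     "peft @ git+https://github.com/huggingface/peft.git", ...]
#   (the auto-gptq and flash-attn lines are filtered out, as are comment lines)
# dependency_links ≈ ["https://download.pytorch.org/whl/cu118",
#                     "https://huggingface.github.io/autogptq-index/whl/cu118/"]
#   (each --extra-index-url line contributes its URL)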


@@ -1,12 +1,13 @@
"""Module containing Dataset functionality"""
import logging
import os
from typing import List
import torch
from datasets import IterableDataset
from datasets import Dataset, IterableDataset
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# let's use the concept of middlewares to wrap each dataset, for example
@@ -14,10 +15,12 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(IterableDataset):
class TokenizedPromptDataset(Dataset):
"""
Iterable dataset that returns tokenized prompts from a stream of text files.
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
@@ -27,18 +30,19 @@ class TokenizedPromptDataset(IterableDataset):
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
self.dataset = dataset
super().__init__(self.process(dataset).data, **kwargs)
def __iter__(self):
iterator = iter(self.dataset)
# Loop through the entire dataset
for example in iterator:
try:
yield self.prompt_tokenizer.tokenize_prompt(example)
except InvalidDataException:
pass
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
)
# TODO this isn't the best since it can't interleave datasets
@@ -72,14 +76,21 @@ class ConstantLengthDataset(IterableDataset):
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
@@ -101,6 +112,9 @@ class ConstantLengthDataset(IterableDataset):
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
@@ -109,19 +123,23 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
logging.warning(
LOG.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
@@ -137,13 +155,17 @@ class ConstantLengthDataset(IterableDataset):
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=self.tokens_dtype
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)
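
A toy illustration of the packing bookkeeping above (token counts are made up, not from the diff): if two tokenized examples of lengths 3 and 2 land in the same buffer, the per-example counter idx tags their attention masks and torch.arange restarts the position ids, so the packed sample looks roughly like

# input_ids:      [t0, t1, t2, u0, u1]
# attention_mask: [ 1,  1,  1,  2,  2]   # [idx * m for m in attention_mask]
# position_ids:   [ 0,  1,  2,  0,  1]   # torch.arange(len(input_ids)) per example

which is what the packed flash-attention patch and the _expand_mask helper later in this diff rely on to keep the two sub-sequences from attending to each other.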


@@ -1,139 +0,0 @@
"""Flash attention monkey patch for llama model"""
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
from typing import Optional, Tuple
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,
None,
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward


@@ -0,0 +1,70 @@
"""
Common logging module for axolotl
"""
import os
import sys
from logging import Formatter
from logging.config import dictConfig
from typing import Any, Dict
from colorama import Fore, Style, init
class ColorfulFormatter(Formatter):
"""
Formatter to add coloring to log messages by log type
"""
COLORS = {
"WARNING": Fore.YELLOW,
"ERROR": Fore.RED,
"CRITICAL": Fore.RED + Style.BRIGHT,
}
def format(self, record):
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
},
"filters": {},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "simple",
"filters": [],
"stream": sys.stdout,
},
"color_console": {
"class": "logging.StreamHandler",
"formatter": "colorful",
"filters": [],
"stream": sys.stdout,
},
},
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"level": "DEBUG",
"propagate": False,
},
},
}
def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)
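
A minimal usage sketch for the logging module above (the module name "axolotl.my_module" is illustrative):

import logging
from axolotl.logging_config import configure_logging

configure_logging()  # apply DEFAULT_LOGGING_CONFIG: root at LOG_LEVEL, axolotl loggers at DEBUG
LOG = logging.getLogger("axolotl.my_module")
LOG.debug("emitted, because the axolotl logger tree is set to DEBUG")
LOG.warning("rendered in yellow by ColorfulFormatter")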


@@ -0,0 +1,598 @@
"""Flash attention monkey patch for llama model"""
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
)
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# flash-attn v2 start
#
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
if cu_seqlens is not None and max_seqlen is not None:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
#
# flash-attn v2 end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
"""
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*): cumulative sequence lengths when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
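
The patch above imports get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils, which is not part of this diff. A minimal sketch of what such a helper could look like, assuming batch size 1 and position ids that restart at 0 for every packed sub-sequence (an illustration, not the actual implementation):

import torch

def get_cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
    # position_ids: (1, seq_len), e.g. tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
    row = position_ids[0]
    seq_len = row.numel()
    # a new sub-sequence starts wherever the position counter resets to 0
    starts = torch.nonzero(row == 0, as_tuple=False).flatten()
    ends = torch.cat([starts[1:], torch.tensor([seq_len], device=row.device)])
    lengths = ends - starts
    # flash-attn's varlen kernels expect cumulative lengths, e.g. [0, 3, 5, 9]
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=row.device),
            lengths.cumsum(0).to(torch.int32),
        ]
    )
    max_seqlen = int(lengths.max())
    # the caller squeezes away the batch dim, as llama_model_forward above does
    return cu_seqlens.unsqueeze(0), max_seqlen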


@@ -0,0 +1,140 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value


@@ -0,0 +1,155 @@
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
import xformers.ops
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# xformers-attn start
#
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# xformers-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value


@@ -0,0 +1,52 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def hijack_expand_mask():
import transformers
transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
_expand_mask
)
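
A small, hypothetical example of the mask expansion above (values chosen for illustration; assumes the _expand_mask function defined above is in scope): two packed sequences of length 2 followed by one padding position in a single row.

import torch

mask = torch.tensor([[1, 1, 2, 2, 0]])  # sequence ids 1 and 2, then padding
expanded = _expand_mask(mask, dtype=torch.float32)
# expanded.shape == (1, 1, 5, 5); entry [q, k] is 0.0 only when query position q may
# attend to key position k: same packed-sequence id, k <= q, and k is not padding.
# Every other entry is torch.finfo(torch.float32).min, i.e. masked out before the softmax.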

File diff suppressed because it is too large


@@ -0,0 +1,393 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.relora")
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
if key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
self.resume_from_checkpoint = cfg.resume_from_checkpoint
if not os.path.exists(self.last_full_model):
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
assert os.path.exists(
self.last_full_model
), "for ReLORA base_model must be a local path"
self.num_lora_restarts = 0
self.need_full_save = False
def on_train_begin(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
if self.resume_from_checkpoint:
weight_path = os.path.join(self.resume_from_checkpoint, "relora")
if not os.path.exists(weight_path):
LOG.warning(
"Resuming ReLoRA from checkpoint, but no full-weight save found"
)
else:
LOG.info(f"Loading adjusted base weights from {weight_path}")
load_weight_checkpoint(model, weight_path)
return control
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"relora",
)
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
checkpoint_folder,
reinit=True,
quantized=self.quantized,
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
reset_optimizer(optimizer)
if self.quantized:
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1
return control
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
)
if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
):
if self.quantized:
if is_main_process() and self.last_full_model != checkpoint_folder:
# ensure the latest full parameter save is in the latest checkpoint
# folder, so that automatic pruning of checkpoints does not remove it
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
os.makedirs(checkpoint_folder, exist_ok=True)
chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors"
) + glob.glob(f"{self.last_full_model}/model*.index.json")
for path in chunks:
new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
try:
os.symlink(new_path, path)
except OSError:
# probably on windows without permission to symlink
pass
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
return control
def on_log(
self,
_args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
logs: Dict[str, float],
**_kwargs,
):
logs["num_lora_restarts"] = self.num_lora_restarts
return control
def on_train_end(
self,
args: TrainingArguments,
_state: TrainerState,
control: TrainerControl,
model: peft.LoraModel,
**_kwargs,
):
if self.quantized:
# perform final merge and save
with torch.no_grad():
merge_and_save(
model,
self.last_full_model,
args.output_dir,
reinit=False,
quantized=self.quantized,
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
# no need to save if unquantized, as finetune.py will call merge_and_unload()
return control
class ReLoRAScheduler(LRScheduler):
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
return [lr * scale for lr in original]
return original * scale
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
str(Path(path) / f"{model_name}.index.json")
):
model_name = "pytorch_model.bin"
index_path = str(Path(path) / f"{model_name}.index.json")
if os.path.exists(index_path):
with open(index_path, "r", encoding="utf-8") as file:
data = json.load(file)
return data["weight_map"]
return {(module_name + ".weight"): model_name for module_name in module_names}
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device)
@ layer.lora_A[adapter].weight.detach().to(device),
getattr(layer, "fan_in_fan_out", False),
)
* layer.scaling[adapter]
)
return layer.get_delta_weight().to(device)
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
modules: Dict[str, peft.tuners.lora.LoraLayer] = {}
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
for key in key_list:
try:
# pylint: disable=protected-access
_parent, target, _target_name = peft.utils._get_submodules(model.model, key)
except AttributeError:
continue
if isinstance(target, peft.tuners.lora.LoraLayer):
modules[key] = target
return modules
def update_weights(
target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
):
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs
# when the module is moved from cpu to gpu. Without meddling *too* deeply in
# PEFT's innards or maintaining a duplicate of that codepath, this is good
# enough for now.
target.weight.quant_state = None
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
else:
target.weight.data = new_weight.to(device)
def merge_and_save(
model: peft.LoraModel,
model_src: str,
model_dst: str,
reinit: bool = False,
quantized: bool = False,
cpu_offload: bool = False,
actually_save: bool = True,
):
modules = find_lora_modules(model)
if not quantized:
for module_name, target in modules.items():
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
shard_paths = sharded_paths(model_src, modules.keys())
out_shard_paths = {}
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
out_tensors = {}
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]
for module_name, target in modules.items():
key = module_name + ".weight"
if key not in shard_paths or shard_paths[key] != shard_path:
continue
orig_weight = in_tensors[key]
old_dev = target.weight.device
math_dev = "cpu" if cpu_offload else old_dev
delta_weight = lora_delta_weight(target, math_dev)
new_weight = orig_weight.to(math_dev) + delta_weight
del delta_weight
if actually_save:
out_tensors[key] = new_weight.half().cpu()
update_weights(target, new_weight, reinit=reinit, device=old_dev)
if actually_save:
out_shard_name = shard_path
if out_shard_name.startswith("pytorch_model"):
out_shard_name = (
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors"
)
for module_name in in_tensors:
if module_name not in out_tensors:
out_tensors[module_name] = in_tensors[module_name].half()
out_shard_paths[module_name] = out_shard_name
shard_fn = str(Path(model_dst) / out_shard_name)
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
del in_tensors
del out_tensors
torch.cuda.empty_cache()
if actually_save and len(unique_shards) > 1:
with open(
str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
) as file:
json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)
def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
modules = find_lora_modules(model)
shard_paths = sharded_paths(checkpoint_path, modules.keys())
unique_shards = list(set(shard_paths.values()))
for shard_path in unique_shards:
tensors = st.load_file(os.path.join(checkpoint_path, shard_path))
for module_name, target in modules.items():
key = module_name + ".weight"
if key not in shard_paths or shard_paths[key] != shard_path:
continue
new_weight = tensors[key]
update_weights(
target, new_weight, reinit=False, device=target.weight.device
)
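The ReLoRAScheduler above only rescales whatever the wrapped schedule returns, warming the learning rate back up after each LoRA restart. A minimal wiring sketch, assuming a toy optimizer and inner schedule (the model, step counts, and schedule below are hypothetical, not from this diff):

```
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)  # stand-in module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
inner = LambdaLR(optimizer, lambda step: 1.0)  # stand-in for the real warmup/decay schedule

scheduler = ReLoRAScheduler(optimizer, inner, relora_steps=100, warmup_steps=10)
for _ in range(300):
    optimizer.step()
    # after step 100, the lr drops to min_lr_scale of the inner value and ramps back up
    # over warmup_steps at the start of every 100-step cycle
    scheduler.step()
```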

View File

@@ -0,0 +1,103 @@
"""
Shared utils for the monkeypatches
"""
import torch
def get_cu_seqlens(attn_mask):
"""generate a cumulative sequence length mask for flash attention using attn mask"""
if len(attn_mask.shape) == 1:
attn_mask = attn_mask.unsqueeze(0)
device = attn_mask.device
results = []
max_seq_lens = []
for row in attn_mask:
# Exclude zeros to avoid adding their positions to the mask
t_non_zeros = row[row != 0]
# Find where the sequence number changes (including the first position)
seq_change = torch.cat(
[
torch.tensor([1], dtype=torch.int32, device=device),
t_non_zeros[1:] != t_non_zeros[:-1],
]
)
# Get the indices where the sequence changes
change_indices = torch.cat(
[
(seq_change == 1).nonzero(as_tuple=True)[0],
torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = change_indices[1:] - change_indices[:-1]
# Calculate the length of the final sequence or padding
final_seq_length = len(row) - change_indices[-1]
# Append the length of the final sequence or padding to seq_lengths
if final_seq_length.item():
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[final_seq_length.item()], dtype=torch.int32, device=device
),
]
)
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
)
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
max_seq_lens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
(seq_starts).nonzero(as_tuple=True)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
)
# Append the padding length to the cumulative sequence lengths
if padding_length:
cu_seqlens = torch.cat(
[cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
)
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
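Both helpers return one row of cumulative sequence boundaries per batch row plus the longest packed sequence length. A quick illustration with an invented packed batch (values assumed, not from this diff):

```
import torch

# two packed sequences of lengths 3 and 4, followed by 2 padding positions
attn_mask = torch.tensor([[1, 1, 1, 2, 2, 2, 2, 0, 0]])
cu, max_len = get_cu_seqlens(attn_mask)
# cu      -> tensor([[0, 3, 7, 9]], dtype=torch.int32)
# max_len -> tensor([4])

# the same packing expressed as per-sample position ids
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 0, 0]])
cu, max_len = get_cu_seqlens_from_pos_ids(position_ids)
# recovers the same boundaries and max length
```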

View File

@@ -0,0 +1,94 @@
# pylint: skip-file
"""
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
"""
import torch
import transformers
import transformers.models.llama.modeling_llama
from einops import rearrange
class XposRotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scale_base=2048,
use_xpos=True,
):
super().__init__()
self.max_seq_len_cached = max_position_embeddings
self.scale_base = scale_base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
freqs = torch.einsum("i , j -> i j", t, inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("freqs_cached", freqs, persistent=False)
if not use_xpos:
self.register_buffer("scale", None)
self.register_buffer("scale_cached", torch.ones(1))
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
scale_cached = scale ** rearrange(power, "n -> n 1")
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
self.register_buffer("scale", scale, persistent=False)
self.register_buffer("scale_cached", scale_cached, persistent=False)
def forward(
self,
x,
seq_len,
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
self.inv_freq
)
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
self.register_buffer("freqs_cached", freqs)
if self.scale is None:
self.register_buffer(
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
)
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, "n -> n 1")
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
self.register_buffer("scale_cached", scale)
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
freqs = freqs[position_ids, :]
if scale.shape[-1] != 1:
scale = scale[position_ids, :]
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
return q_embed, k_embed
def replace_llama_rope_with_xpos_rope():
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
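The patched rotary-embedding class is picked up when the Llama modules are constructed, so the monkeypatch should normally run before the model is loaded. A hedged usage sketch (the checkpoint name is a placeholder):

```
from transformers import AutoModelForCausalLM

replace_llama_rope_with_xpos_rope()  # patch first
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
```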

View File

@@ -2,8 +2,10 @@
import importlib
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
def load(strategy, tokenizer, cfg):
def load(strategy, tokenizer, cfg, ds_cfg):
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
@@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg):
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
func = getattr(mod, load_fn)
return func(tokenizer, cfg)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
return func(tokenizer, cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
return None

View File

@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg):
@@ -18,6 +18,42 @@ def load(tokenizer, cfg):
)
class AlpacaConcisePrompter(AlpacaPrompter):
"""
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
"""
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
class AlpacaChatPrompter(AlpacaPrompter):
"""
Alpaca Chat Prompter extending the system prompt for chat-instruct answers
"""
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
def __init__(self): # pylint: disable=super-init-not-called
self.prompt_style = PromptStyle.CHAT.value
self.match_prompt_style()
class NoSystemPrompter(AlpacaPrompter):
"""
Null Prompter with no system prompts
"""
system_prompt = ""
system_no_input_prompt = ""
turn_format = "{instruction} {input} "
turn_no_input_format = "{instruction} "
def __init__(self): # pylint: disable=super-init-not-called
pass
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for AlpacaQA
@@ -31,9 +67,49 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
)
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value),
class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for CamelAI datasets
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["message_1"],
"",
prompt["message_2"],
)
def load_concise(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
AlpacaConcisePrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
AlpacaChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_camel_ai(tokenizer, cfg):
return CamelAIPromptTokenizingStrategy(
AlpacaChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_no_prompt(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
UnpromptedPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -1,7 +1,7 @@
"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg):
@@ -11,3 +11,12 @@ def load(tokenizer, cfg):
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_no_prompt(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
UnpromptedPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -0,0 +1,163 @@
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
prompt["system"],
)
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
(
instruction,
input, # pylint: disable=redefined-builtin
response,
system,
) = self.parse_instruction_fields(prompt)
user_prompt = next(
iter(
self.prompter.build_prompt_w_system(
system,
instruction,
input,
)
)
)
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_prompt
class SystemDataPrompter(AlpacaPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset
"""
system_format: str = "### System:\n{system}\n\n"
def build_prompt_w_system(
self,
system: str,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
formatted_sys_prompt = (
self.system_format.format(system=system)
if system and self.system_format
else ""
)
if input:
res = formatted_sys_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else:
res = formatted_sys_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
yield res
class OpenOrcaSystemDataPrompter(SystemDataPrompter):
"""
Alpaca-style prompter that uses system prompts from the dataset, with OpenOrca prompt formats
"""
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n"
self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n"
self.system_format = "### System:\n{system}\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
"""
Tokenizing strategy for OpenOrca datasets
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
return (
prompt["question"],
"",
prompt["response"],
prompt["system_prompt"],
)
def load(tokenizer, cfg):
return load_chat(tokenizer, cfg)
def load_instruct(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_chat(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca_chatml(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
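A hedged illustration of what the instruct-style OpenOrca prompter renders for one invented row, before the response is appended:

```
prompter = OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value)
user_prompt = next(prompter.build_prompt_w_system(
    "You are a helpful assistant.",  # system_prompt column
    "What is 2 + 2?",                # question column
    "",                              # no additional input
))
# "### System:\nYou are a helpful assistant.\n### Human:\nWhat is 2 + 2?\n### Assistant:\n"
```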

View File

@@ -0,0 +1,67 @@
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
# article, unanswerable_question, question, answer
def load_404(tokenizer, cfg):
return AlpacaMissingInfoContextPromptTokenizingStrategy(
AlpacaContextPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load(tokenizer, cfg):
return AlpacaContextPromptTokenizingStrategy(
AlpacaContextPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class AlpacaContextPrompter(AlpacaPrompter):
"""
Prompter with a customized system prompt for concise QA
"""
system_prompt = (
"Use the following contextual information to concisely answer the question.\n"
)
system_no_input_prompt = (
"Use the following contextual information to concisely answer the question.\n"
)
class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenization Strategy to combine in-context article with a question and answer
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["article"] + "\n===\n" + prompt["question"],
"",
prompt["answer"],
)
class AlpacaMissingInfoContextPromptTokenizingStrategy(
InstructionPromptTokenizingStrategy
):
"""
Tokenization Strategy to combine in-context article with a question that can't be answered
from the context and a default response to that effect
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
"",
"The context provided does not contain any information about your inquiry. "
"Therefore, I'm unable to answer your question based on the given context.",
)

View File

@@ -0,0 +1,205 @@
"""
Prompt Strategy for finetuning Llama2 chat models
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for a reference implementation.
This implementation is based on the Vicuna PR and the fastchat repo, see also:
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847
Use dataset type: "llama2_chat" in config.yml to use this prompt style.
E.g. in the config.yml:
```
datasets:
- path: llama_finetune_train.jsonl
type: llama2_chat
```
The dataset itself should look like this:
```
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]}
```
in a jsonl file. The first message should be from the human, the second from gpt.
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""
import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
@dataclass
class Llama2ChatConversation:
"""A class that manages prompt templates and keeps all conversation history.
copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""
name: str = "llama2"
# The system prompt
system: str = (
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
)
roles: Sequence[str] = ("[INST]", "[/INST]")
messages: List[List[str]] = field(default_factory=list)
offset: int = 0
sep = " "
sep2 = " </s><s>"
stop_token_ids = [2]
def get_prompt(self) -> str:
"""Get the prompt for generation."""
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(self.messages):
if (i == len(self.messages) - 1) and (role == self.roles[0]):
# last message is from user (due to length),
# return prompt without it for training
return ret
if i == 0:
ret += self.system + message.strip()
else:
ret += role + " " + message.strip() + seps[i % 2]
return ret
def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for Llama-2 chat prompts in ShareGPT conversation format.
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sequence_len = 4096
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
def tokenize_prompt(self, prompt):
conv = next(self.prompter.build_prompt(prompt))
conversation_str = conv.get_prompt()
# Tokenize conversations
input_ids = self.tokenizer(
conversation_str,
return_tensors="pt",
padding="max_length",
max_length=self.sequence_len,
truncation=True,
).input_ids[0]
target = input_ids.clone()
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.roles[1]
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
turns = conversation_str.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
for turn in turns:
if turn == "":
break
turn_len = len(self.tokenizer(turn).input_ids)
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
# "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1
# Ignore the user instructions
target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += turn_len + 2 # due to length of role token
target[cur_len:] = IGNORE_TOKEN_ID
if cur_len < self.sequence_len:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
logging.warning(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
input_ids = input_ids.tolist()
target = target.tolist()
# this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
# follows the original llama implementation
for i in range(2, total_len - 2):
if input_ids[i] == 29961:
input_ids[i] = 518
if target[i] == 29961:
target[i] = 518
return {
"input_ids": input_ids,
"labels": target,
"attention_mask": attention_mask,
}
class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for Llama2 models.
"""
system_prompt = (
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
)
def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
# see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
source = source["conversations"] # fix data structure for datasets
# if system prompt provided, use it
if source[0]["from"] == "system":
system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
source = source[1:]
else:
system = self.system_prompt
conv = Llama2ChatConversation(system=system)
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
if sentence["value"]:
conv.append_message(role, sentence["value"])
yield conv
def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
return LLama2ChatTokenizingStrategy(
Llama2ChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -0,0 +1,76 @@
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
import logging
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
# pylint: disable=duplicate-code
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for the Metharme models
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (prompt["prompt"], "", prompt["generation"])
def _tokenize(
self,
prompt: str,
add_eos_token: bool = True,
strip_bos_token: bool = False,
num_eos_tokens: int = 3,
):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
# If there's already an EOS token there, subtract from the number added
if result["input_ids"][-1] == self.tokenizer.eos_token_id:
num_eos_tokens -= 1
if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
for _ in range(num_eos_tokens):
if len(result["input_ids"]) < self.sequence_len:
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
class MetharmePrompter(AlpacaPrompter):
"""
Prompter for the Metharme models.
"""
system_prompt = ""
system_no_input_prompt = ""
system_format = ""
turn_format = "{instruction}"
turn_no_input_format = "{instruction}"
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
pass
def load(tokenizer, cfg):
return MetharmePromptTokenizingStrategy(
MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)

View File

@@ -0,0 +1,46 @@
"""
Prompt Strategy for finetuning Orca Mini (v2) models
see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information
Use dataset type: orcamini in config.yml to use this prompt style.
Compared to the alpaca_w_system.open_orca dataset type,
this one specifies the system prompt with "### System:".
Not suited/tested for multiple-turn conversations without further adjustments.
"""
from typing import Generator, Union
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
class OrcaMiniPrompter(AlpacaPrompter):
"""Adjusted Prompter for Orca Mini (v2) datasets"""
def match_prompt_style(self):
self.turn_no_input_format = (
"### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n"
)
def build_prompt_w_system(
self,
system: str,
instruction: str,
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
res = self.turn_no_input_format.format(system=system, instruction=instruction)
if output:
res = f"{res}{output}"
yield res
def load(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
OrcaMiniPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
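For comparison with alpaca_w_system.open_orca, a hedged illustration of the Orca Mini rendering for one invented row:

```
prompter = OrcaMiniPrompter()
user_prompt = next(prompter.build_prompt_w_system(
    "You are an AI assistant that follows instructions well.",
    "Name three prime numbers.",
))
# "### System:\nYou are an AI assistant that follows instructions well.\n\n"
# "### User:\nName three prime numbers.\n\n### Response:\n"
```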

View File

@@ -11,6 +11,8 @@ from axolotl.prompt_tokenizers import (
tokenize_prompt_default,
)
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -64,7 +66,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
*copy.deepcopy(res["input_ids"])
][len(self.bot_prefix_token_ids) :]
else:
logging.warning(f"unknown role in conversation: {role}")
LOG.warning(f"unknown role in conversation: {role}")
res = defaultdict(lambda: [])
# pylint: disable=duplicate-code

View File

@@ -0,0 +1,28 @@
"""Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
Tokenization strategy for asking the bot to tell a joke and then explain why it's funny
"""
# title, text, explanation
def get_conversation_thread(self, prompt):
title = "" if not prompt["title"] else prompt["title"] + " "
return [
{"from": "human", "value": "Tell me a joke."},
{"from": "gpt", "value": title + prompt["text"]},
{"from": "human", "value": "Why is that joke funny?"},
{"from": "gpt", "value": prompt["explanation"]},
]

View File

@@ -0,0 +1,67 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg):
return SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# rename the "role" key to "from", keeping the role names and "value" as-is
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
]
return turns
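A short hedged illustration of the oasst-to-sharegpt remap performed by the guanaco strategy (the rows are invented):

```
row = {"conversations": [
    {"role": "prompter", "text": "What is the capital of France?"},
    {"role": "assistant", "text": "Paris."},
]}
# GuanacoShareGPTPromptTokenizingStrategy.get_conversation_thread(row) returns:
# [{"from": "human", "value": "What is the capital of France?"},
#  {"from": "gpt", "value": "Paris."}]
```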

View File

@@ -0,0 +1,98 @@
"""
User Defined prompts with configuration from the YML config
"""
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
SystemDataPrompter,
)
@dataclass
class UserDefinedDatasetConfig:
"""
dataclass configuration representing a userdefined dataset type
"""
system_prompt: str = ""
field_system: str = "system"
field_instruction: str = "instruction"
field_input: str = "input"
field_output: str = "output"
format: str = "{instruction} {input} "
no_input_format: str = "{instruction} "
system_format: str = "{system}"
def __getitem__(self, item):
return getattr(self, item)
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
"""
Prompt Tokenization Strategy for user defined prompts
"""
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
if not ds_cfg:
raise ValueError("Missing dataset prompt configuration")
system_prompt = ""
if ds_cfg.system_prompt:
system_prompt = ds_cfg.system_prompt
def parse_instruction_fields(
field_instruction,
field_input,
field_output,
field_system,
system_prompt,
prompt,
) -> Tuple[str, str, str, str]:
return (
prompt[field_instruction],
prompt[field_input] if field_input in prompt else "",
prompt[field_output] if field_output in prompt else "",
prompt[field_system] if field_system in prompt else system_prompt,
)
turn_format = ds_cfg.format
turn_no_input_format = ds_cfg.no_input_format
system_format = ds_cfg.system_format
class UserDefinedPrompter(SystemDataPrompter):
"""
Prompter for user defined prompts
"""
def match_prompt_style(self):
self.turn_format = turn_format
self.turn_no_input_format = turn_no_input_format
self.system_format = system_format
prompter = UserDefinedPrompter()
strat = UserDefinedPromptTokenizationStrategy(
prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
setattr(
strat,
"parse_instruction_fields",
partial(
parse_instruction_fields,
ds_cfg.field_instruction,
ds_cfg.field_input,
ds_cfg.field_output,
ds_cfg.field_system,
system_prompt,
),
)
return strat
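A minimal sketch of building the user-defined strategy directly; the tokenizer checkpoint, field names, and formats below are assumptions for illustration, not values from this diff:

```
from transformers import AutoTokenizer

from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
cfg = DictDefault({"train_on_inputs": False, "sequence_len": 2048})
ds_cfg = UserDefinedDatasetConfig(
    field_instruction="question",
    field_output="answer",
    format="USER: {instruction} {input}\nASSISTANT:",
    no_input_format="USER: {instruction}\nASSISTANT:",
)
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
tokenized = strategy.tokenize_prompt({"question": "What color is the sky?", "answer": "Blue."})
```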

View File

@@ -10,8 +10,10 @@ from transformers import PreTrainedTokenizer
from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
@@ -46,16 +48,22 @@ class PromptTokenizingStrategy(abc.ABC):
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
@@ -66,15 +74,22 @@ class PromptTokenizingStrategy(abc.ABC):
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
@@ -87,7 +102,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
def parse_instruction_fields(
self, prompt
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
raise NotImplementedError
def tokenize_prompt(self, prompt):
@@ -96,25 +113,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
input, # pylint: disable=redefined-builtin
response,
) = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, response)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = next(
iter(
self.prompter.build_prompt(
instruction,
input,
)
user_prompt = next(
iter(
self.prompter.build_prompt(
instruction,
input,
)
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
)
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_full_prompt
return tokenized_prompt
def _build_full_prompt(
self, instruction, input, response # pylint: disable=redefined-builtin
@@ -380,7 +399,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
logging.warning(f"unhandled role: {part[0]}")
LOG.warning(f"unhandled role: {part[0]}")
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
@@ -436,7 +455,7 @@ def parse_tokenized_to_result(
result: Dict[str, List[int]],
current_len: int,
res: Dict[str, List[int]],
labels: list[int],
labels: List[int],
pad_token_id: Union[int, None] = None,
) -> Tuple[Dict[str, List[int]], int]:
"""

View File

@@ -5,6 +5,7 @@ import logging
from enum import Enum, auto
from typing import Generator, List, Optional, Tuple, Union
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -15,6 +16,7 @@ class PromptStyle(Enum):
INSTRUCT = "instruct"
CHAT = "chat"
CHATML = "chatml"
class AlpacaPrompter:
@@ -24,6 +26,9 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str = "{system}"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -31,24 +36,23 @@ class AlpacaPrompter:
self.match_prompt_style()
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = (
self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
self.system_format = "### System:\n{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = (
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:"
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def build_prompt(
self,
@@ -59,16 +63,21 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.prompt_input.format(instruction=instruction, input=input)
res = (
self.system_format.format(system=self.system_prompt)
if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input)
else:
res = self.prompt_no_input.format(instruction=instruction)
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class UnpromptedPrompter(AlpacaPrompter):
"""
@@ -93,7 +102,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
"""
system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning."
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
system_no_input_prompt = (
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
@@ -102,7 +114,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
Prompter for multiple choice concise
"""
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
def match_prompt_style(self):
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +127,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
Prompter for summarize TLDR
"""
prompt_no_input = (
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
)
system_prompt = ""
system_no_input_prompt = ""
def match_prompt_style(self):
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter:
@@ -128,9 +148,6 @@ class CompletionPrompter:
) -> Generator[str, None, None]:
yield instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter):
"""
@@ -210,9 +227,6 @@ class ReflectAlpacaPrompter:
res = f"{res}{label}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SeparatorStyle(Enum):
"""Different separator style."""
@@ -243,7 +257,7 @@ class Conversation:
if message:
yield (role + ":", " " + message)
else:
logging.warning(f"role with empty message: {role}")
LOG.warning(f"role with empty message: {role}")
yield (role + ":", "")
def copy(self):
@@ -261,15 +275,8 @@ class Conversation:
self.messages.append([role, message])
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
@@ -278,17 +285,28 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
A prompter that generates prompts for the ShareGPT
"""
def __init__(self, prompt_style=None):
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
if prompt_style != PromptStyle.CHAT.value:
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
system: str = (
system_prompt
if system_prompt
else (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
)
self._conversation = Conversation(
system=system,
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
)
def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
@@ -298,9 +316,11 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError
raise IndexError(
f"A conversation entry has less than 2 messages :\n{source}"
)
conv = conv_vicuna_v1_1.copy()
conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
try:
@@ -318,7 +338,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
conv.append_message(role, sentence["value"])
for part in conv.get_prompt():

View File

@@ -0,0 +1,43 @@
"""Benchmarking and measurement utilities"""
import pynvml
import torch
def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3
def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc
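A small hedged example of calling the logging helper (logger name and device index are assumptions):

```
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("axolotl.bench_demo")
log_gpu_memory_usage(log, "before model load", 0)  # returns (0, 0, 0) when CUDA is unavailable
```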

Some files were not shown because too many files have changed in this diff.