Compare commits

...

340 Commits

Author SHA1 Message Date
Wing Lian
81d60e96f0 multipack sampler support from openchat
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-07-15 08:01:33 -04:00
NanoCode012
168a7a09cc Merge pull request #274 from OpenAccess-AI-Collective/NanoCode012-patch-2
Feat: Set push to hub as private by default
2023-07-14 23:15:47 +09:00
NanoCode012
231031a0e1 Merge pull request #275 from NanoCode012/feat/safetensors
Feat: Add save_safetensors
2023-07-14 23:07:26 +09:00
NanoCode012
5daf7d5299 Merge pull request #273 from OpenAccess-AI-Collective/NanoCode012-patch-1
Feat(docs): Add model_revision arg
2023-07-14 21:09:50 +09:00
NanoCode012
5491278a79 Feat: Add save_safetensors 2023-07-14 13:21:47 +09:00
NanoCode012
1514739f0f Set push to hub as private by default 2023-07-14 13:17:49 +09:00
NanoCode012
896c1aebcf Feat(docs): Add model_revision arg 2023-07-14 12:56:07 +09:00
Wing Lian
ef17e15483 Merge pull request #272 from OpenAccess-AI-Collective/model-revision
support for loading a model by git revision
2023-07-13 23:12:00 -04:00
Wing Lian
69a235061b support for loading a model by git revision 2023-07-13 22:58:25 -04:00
Wing Lian
687d889928 Merge pull request #271 from OpenAccess-AI-Collective/quadratic-warmup
Quadratic warmup
2023-07-10 12:48:02 -04:00
Wing Lian
c4cf567b55 Merge branch 'main' into quadratic-warmup 2023-07-10 12:42:12 -04:00
Wing Lian
c49729d2bc better configuration for quadratic warmup 2023-07-10 11:52:59 -04:00
Wing Lian
13ac4d8de2 Merge pull request #268 from OpenAccess-AI-Collective/fix-adam-args
params are adam_*, not adamw_*
2023-07-08 12:33:34 -04:00
Wing Lian
19cf0bda99 params are adam_*, not adamw_* 2023-07-08 12:13:39 -04:00
Wing Lian
f74edd5b56 Merge pull request #266 from OpenAccess-AI-Collective/trust-remote-no-llama 2023-07-07 21:38:11 -04:00
Wing Lian
d69da99c2c skip explicit model type too if using trust_remote_code 2023-07-07 21:33:11 -04:00
Wing Lian
66afb76a15 don't use llama if trust_remote_code is set since that needs to use AutoModel path 2023-07-07 21:31:02 -04:00
NanoCode012
a692ad3f4c Merge pull request #264 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix(readme): local path loading and custom strategy type
2023-07-06 23:34:57 +09:00
NanoCode012
41da98b982 Fix for linter 2023-07-06 23:20:11 +09:00
NanoCode012
9e64f42e0f Fix local path loading and custom strategy type 2023-07-06 23:08:09 +09:00
Wing Lian
b9b7d4ce92 Merge pull request #221 from utensil/local_dataset
[WIP] Support loading data files from a local directory
2023-07-03 09:10:13 -04:00
Wing Lian
9bed281867 Merge pull request #258 from NanoCode012/fix/deprecate-push
Fix future deprecation push_to_hub_model_id
2023-07-03 09:08:26 -04:00
NanoCode012
e79c8e617e Fix future deprecation push_to_hub_model_id 2023-07-03 12:44:29 +09:00
Wing Lian
71456955f5 pin pydantic so deepspeed isn't broken 2023-07-02 22:26:51 -04:00
Wing Lian
3a783c04e4 Merge pull request #247 from OpenAccess-AI-Collective/fix-apex-base
update pip install command for apex
2023-07-01 06:18:25 -04:00
Wing Lian
1e5014acec Merge pull request #255 from OpenAccess-AI-Collective/open-orca-prompts
open orca support
2023-07-01 01:11:23 -04:00
Wing Lian
a10da1caff 11.7.0 nvidia/cuda docker images are deprecated, move to 11.7.1
Some checks failed
ci-cd-base / build-base (<nil>, 117, 11.7.1, 3.9, 1.13.1) (push) Has been cancelled
ci-cd-base / build-base (<nil>, 118, 11.8.0, 3.10, 2.0.0) (push) Has been cancelled
ci-cd-base / build-base (<nil>, 118, 11.8.0, 3.9, 2.0.0) (push) Has been cancelled
ci-cd-base / build-base (gptq, 118, 11.8.0, 3.9, 2.0.0) (push) Has been cancelled
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-07-01 00:29:07 -04:00
Wing Lian
4066c78631 Merge pull request #246 from OpenAccess-AI-Collective/sys-prompts-instruct
add option for instruct w sys prompts
2023-07-01 00:27:29 -04:00
Wing Lian
78a1e1fa12 open orca support 2023-07-01 00:19:41 -04:00
NanoCode012
bc8a2e5547 Merge pull request #249 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix typing list in prompt tokenizer
2023-06-30 15:01:41 +09:00
NanoCode012
910ebe47f5 Merge pull request #252 from OpenAccess-AI-Collective/NanoCode012-readme-fix
Add cfg.push_to_hub_model_id to readme
2023-06-30 14:56:55 +09:00
NanoCode012
c146880a75 Update README.md 2023-06-30 11:33:53 +09:00
NanoCode012
77bdb7d144 Fix typing list 2023-06-29 14:29:55 +09:00
Wing Lian
530809fd74 update pip install command for apex 2023-06-28 22:36:28 -04:00
Wing Lian
924bbfddec add option for instruct w sys prompts 2023-06-28 22:27:17 -04:00
Wing Lian
f150c027e3 Merge pull request #224 from OpenAccess-AI-Collective/system-prompt-data
System prompt data
2023-06-27 17:57:43 -04:00
Wing Lian
5c39c006c9 Merge pull request #244 from OpenAccess-AI-Collective/push-to-hub
push intermediate model checkpoints to hub
2023-06-27 17:57:30 -04:00
Wing Lian
612aabd8c4 push intermediate model checkpoints to hub 2023-06-27 15:40:25 -04:00
Wing Lian
af05883f75 Merge pull request #243 from OpenAccess-AI-Collective/unprompted-instruct
skip the system prompt
2023-06-25 22:50:35 -04:00
Wing Lian
05ab9092e3 skip the system prompt 2023-06-25 22:40:50 -04:00
Wing Lian
7b57ed7618 pylint for duplicated code for system prompts 2023-06-25 22:28:07 -04:00
Wing Lian
3a38271276 add tests and supoort for loader for sys prompt data 2023-06-25 22:28:07 -04:00
Wing Lian
8d20e0a3d3 initial wip to get sys prompt from dataset 2023-06-25 22:28:07 -04:00
Wing Lian
de8ed229c3 Merge pull request #240 from OpenAccess-AI-Collective/tokenizer-fast
optionally define whether to use_fast tokenizer
2023-06-25 12:47:55 -04:00
Wing Lian
478d8c7b8e Merge pull request #241 from OpenAccess-AI-Collective/py3-pre-commit
better py3 support w pre-commit
2023-06-25 12:47:02 -04:00
Wing Lian
645c13592c better py3 support w pre-commit 2023-06-25 10:26:02 -04:00
Wing Lian
47d601fa23 optionally define whether to use_fast tokenizer 2023-06-25 10:19:49 -04:00
Wing Lian
756dfba97b Merge pull request #218 from OpenAccess-AI-Collective/no-fail-fast
don't fail fast
2023-06-23 15:42:54 -04:00
Wing Lian
91ab0592af Merge pull request #235 from msinha251/Fixing-data-readme 2023-06-23 13:52:01 -04:00
Mahesh Sinha
0aeb7c7802 Fixing Data Readme 2023-06-21 15:34:48 +02:00
Utensil
9bdd30cdfd Support loading data files from a local directory
ref:  https://huggingface.co/docs/datasets/v2.13.0/en/package_reference/loading_methods#datasets.load_dataset.path
2023-06-21 08:00:58 +00:00
Wing Lian
d35278aaf1 don't fail fast 2023-06-15 16:01:27 -04:00
Wing Lian
9492d4ebb7 Merge pull request #215 from OpenAccess-AI-Collective/adamw-hyperparams-cfg
support adamw and grad norm hyperparams
2023-06-15 12:20:55 -04:00
Wing Lian
ad5ca4f734 Additional test case per pr 2023-06-15 10:12:47 -04:00
Wing Lian
cb9d3af5c0 add validation and tests for adamw hyperparam 2023-06-15 09:39:42 -04:00
Wing Lian
c969f0a9dc add docs 2023-06-15 08:43:20 -04:00
Wing Lian
6d0ee4ba34 support adamw and grad norm hyperparams 2023-06-15 08:40:41 -04:00
Wing Lian
a81f52d575 Merge pull request #212 from OpenAccess-AI-Collective/doc-20230615-v1
add float16 docs and tweak typehints
2023-06-15 08:28:57 -04:00
Wing Lian
1925eaf1e6 Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Fix tokenizing labels
2023-06-15 08:13:43 -04:00
Wing Lian
1ab3bf3e67 fix test name 2023-06-15 02:09:33 -04:00
Wing Lian
d7635b7148 hint to what AMP means 2023-06-15 02:06:27 -04:00
Wing Lian
88e17ffc50 add float16 docs and tweak typehints 2023-06-15 02:05:31 -04:00
Wing Lian
baed440fa1 ingore duplicate code in tests 2023-06-15 02:03:53 -04:00
Wing Lian
7925ddce86 bugfix for potential off by one 2023-06-15 01:59:33 -04:00
Wing Lian
6f849809c5 Merge pull request #206 from MaciejKarasek/issue205
issue #205 bugfix
2023-06-14 14:23:38 -04:00
Wing Lian
c16644d05e Merge pull request #209 from sroecker/fix_redpajama_example_tokenizer
Use AutoTokenizer for redpajama example
2023-06-14 14:23:21 -04:00
Steffen Röcker
945c4191a3 Use AutoTokenizer for redpajama example 2023-06-14 20:09:26 +02:00
maciej.karasek
136522f9c9 style correction 2023-06-14 20:02:09 +02:00
maciej.karasek
556fe408b3 issue #205 bugfix 2023-06-14 16:59:57 +02:00
Wing Lian
16bb6276a5 Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum
add support for opimum bettertransformers
2023-06-14 07:50:15 -04:00
NanoCode012
06674a11f2 Merge pull request #202 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix sharegpt type in doc
2023-06-14 09:48:35 +09:00
NanoCode012
3513885f43 Fix sharegpt type 2023-06-14 01:10:58 +09:00
Wing Lian
06652c1c39 Merge pull request #196 from OpenAccess-AI-Collective/openllama-ft-config
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
tweak config to work
2023-06-13 11:51:04 -04:00
NanoCode012
068fc48978 Merge pull request #199 from NanoCode012/chore/prompter-arg
chore: Refactor inf_kwargs out
2023-06-13 17:56:22 +09:00
Wing Lian
aaadacf6b3 Merge pull request #200 from PocketDocLabs/main
Update README.md to include a community showcase
2023-06-13 04:44:34 -04:00
PocketDoc Labs
5ff547dc70 Update README.md to include a community showcase 2023-06-12 22:38:10 -07:00
NanoCode012
dc77c8ebce chore: Refactor inf_kwargs out 2023-06-13 12:01:46 +09:00
NanoCode012
51a4c12242 Merge pull request #197 from mhenrichsen/chore/update-readme
chore: Fix inference README.
2023-06-13 11:53:26 +09:00
Wing Lian
4b43a66a0b update alpaca_chat prompts for instructions to explainn the conversation 2023-06-12 18:38:38 -04:00
mhenrichsen
34ae69989f fix inference 2023-06-12 21:39:19 +02:00
Wing Lian
7dc580b837 add axolotl trainer and quadratic warmup 2023-06-12 13:16:40 -04:00
Wing Lian
fd2c9814c9 Merge branch 'main' into flash-optimum 2023-06-12 13:12:15 -04:00
Wing Lian
2ba4ae8f46 tweak config to work 2023-06-12 10:07:18 -04:00
Wing Lian
93dacba228 Merge pull request #187 from OpenAccess-AI-Collective/strip-peft-device-map
peft no longer needs device_map
2023-06-12 09:10:49 -04:00
Wing Lian
8002ffb41f Merge pull request #177 from NanoCode012/fix/landmark-patch
Fix landmark attention patch
2023-06-12 08:27:12 -04:00
Wing Lian
74ef5cc083 Merge pull request #192 from OpenAccess-AI-Collective/sharegpt-custom-prompt
misc fixes
2023-06-12 08:26:38 -04:00
Wing Lian
5e616d91c0 Merge branch 'main' into strip-peft-device-map 2023-06-12 08:25:54 -04:00
Wing Lian
94f310c7a6 Merge pull request #193 from OpenAccess-AI-Collective/config-fixes-20230612
config fixes
2023-06-12 08:24:52 -04:00
NanoCode012
8e568bbdae Merge pull request #159 from AngainorDev/patch-1
Fix training over existing lora
2023-06-12 20:27:11 +09:00
NanoCode012
e21dab49fd Merge pull request #194 from NanoCode012/fix/config-path
Fix config path after config moved
2023-06-12 19:28:12 +09:00
NanoCode012
52cde69288 Fix config path after config moved 2023-06-12 17:06:15 +09:00
Wing Lian
9a58e99e81 config fixes 2023-06-12 01:52:58 -04:00
Wing Lian
c7dee56b87 add typehints 2023-06-11 19:52:34 -04:00
Wing Lian
aac4b7691e add new sharegpt, refactor prompt so it can be customized later, add exception if no data is processed 2023-06-11 19:42:25 -04:00
NanoCode012
f31a338cbb Merge pull request #191 from OpenAccess-AI-Collective/NanoCode012-patch-1
Add save_steps and eval_steps to Readme
2023-06-12 02:55:37 +09:00
NanoCode012
4cd1deeef2 Add save_steps and eval_steps to Readme 2023-06-12 02:44:46 +09:00
Wing Lian
9ac16ed8d1 Merge pull request #190 from OpenAccess-AI-Collective/fixes-20230711-v2
more config pruning and migrating
2023-06-11 13:27:08 -04:00
Wing Lian
6b3f509d9e forgot to add this file 2023-06-11 11:50:12 -04:00
Wing Lian
336aa3fd48 gptq lora llama is obviously good 2023-06-11 11:05:29 -04:00
Wing Lian
d0d7eaa4f3 update openllama and clean up paths 2023-06-11 11:03:31 -04:00
Wing Lian
a6ebf57e82 fix table formatting 2023-06-11 10:55:32 -04:00
Wing Lian
280832cec2 more matrix updates 2023-06-11 10:52:36 -04:00
Wing Lian
a43bae9ff0 update the support matrix 2023-06-11 10:44:03 -04:00
Wing Lian
effbbf6dd1 more pruning 2023-06-11 10:38:24 -04:00
Wing Lian
c9a149f9e8 add check for attr 2023-06-11 10:11:17 -04:00
Wing Lian
c530e4b9c8 more config pruning and migrating 2023-06-11 10:09:05 -04:00
Wing Lian
f620706776 Merge pull request #189 from OpenAccess-AI-Collective/fixes-20230711
various fixes
2023-06-11 09:49:23 -04:00
Wing Lian
77762a5d6b get rid of some configs, formalize pythioa lora config 2023-06-11 09:41:41 -04:00
Wing Lian
14668fa54e new validation for mpt w grad checkpoints 2023-06-11 09:26:10 -04:00
AngainorDev
b565ecf0a1 Fix strict and Lint 2023-06-11 15:23:38 +02:00
Wing Lian
fe0b76854e match up gradient checkpointing when using lora w config 2023-06-11 09:20:40 -04:00
NanoCode012
e944311442 Merge pull request #186 from akj2018/main
Update FAQS.md
2023-06-11 19:45:06 +09:00
Akshay Jain
e3e7b52a5b Update FAQS.md
Converted (```) to single backtick (') uniformly.
2023-06-10 23:36:14 -07:00
NanoCode012
974dc00a7d Fix set mem_id for inference and refactor 2023-06-11 14:00:54 +09:00
NanoCode012
572d1141e6 Set mem cache args on inference 2023-06-11 12:05:37 +09:00
NanoCode012
a6190c8094 Clean up landmark patching 2023-06-11 11:59:03 +09:00
NanoCode012
563b6d89e6 Fix undefined LlamaForCausalLM and del try except 2023-06-11 11:58:31 +09:00
Wing Lian
cd0a6f6027 peft no longer needs device_map 2023-06-10 22:50:09 -04:00
Akshay Jain
0e664a5ebc Update FAQS.md
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-10 19:26:12 -07:00
Akshay Jain
dd7d16d2eb Update FAQS.md
Updated FAQS.md with backticks around error message
2023-06-10 19:15:50 -07:00
NanoCode012
e285e24f7f Address PR suggestion
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-06-11 10:52:12 +09:00
NanoCode012
919727b4d7 Refactor landmark attention patch 2023-06-11 10:51:05 +09:00
Akshay Jain
5ffefee37f Update FAQS.md
Update FAQS.md with the following statement

Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c
/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598:  arrow::fs::FinalizeS3 was not called even though S3 was initialized.  This could lead to a segmentation fault at exit

try reinstalling bitsandbytes and transformers from source
2023-06-10 18:34:54 -07:00
Wing Lian
d9f713e4e3 Merge pull request #183 from OpenAccess-AI-Collective/inference-from-stdin
pass a prompt in from stdin for inference
2023-06-10 17:06:55 -04:00
Wing Lian
958da70376 fix formatting 2023-06-10 15:28:08 -04:00
Wing Lian
c4e4f8115c pass a prompt in from stdin for inference 2023-06-10 15:07:40 -04:00
Angainor Development
a808bf913f Fix missing cfg. 2023-06-10 20:28:49 +02:00
Wing Lian
01248253a3 Merge pull request #182 from OpenAccess-AI-Collective/fix-llama-ref
fix for local variable 'LlamaForCausalLM' referenced before assignment
2023-06-10 14:25:51 -04:00
Wing Lian
759e8673ce Update scripts/finetune.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-10 14:25:21 -04:00
Wing Lian
0c6f928601 address PR feedback 2023-06-10 14:23:56 -04:00
Wing Lian
eea2731a5e add streaming dataset support for pretraining datasets 2023-06-10 14:23:56 -04:00
Wing Lian
1db46a9c72 linting fix 2023-06-10 14:23:56 -04:00
Wing Lian
ab5cd28acf more gpt-neox long ctx fixes 2023-06-10 14:23:55 -04:00
Wing Lian
1a82082e91 fix bettertransformers save, force it to skip after saving correctly in callback 2023-06-10 14:23:55 -04:00
Wing Lian
1210dc8fd5 more tweaks to do pre-training with bettertransformers 2023-06-10 14:23:55 -04:00
Wing Lian
488a67d75a experimental expansion of ctx len 2023-06-10 14:23:53 -04:00
Wing Lian
71a43f8479 add validation/warning for bettertransformers and torch version 2023-06-10 14:22:31 -04:00
Wing Lian
39619028a3 use pythia-12b, neox-20b is flaky 2023-06-10 14:22:30 -04:00
Wing Lian
8792199799 add flash attn context for efficient training and attempt setting model to train mode: 2023-06-10 14:22:30 -04:00
Wing Lian
1edc30c786 add support for opimum bettertransformers 2023-06-10 14:22:30 -04:00
Wing Lian
14163c15d9 fix for local variable 'LlamaForCausalLM' referenced before assignment 2023-06-10 14:11:13 -04:00
Wing Lian
41e4f6ca31 Merge pull request #181 from OpenAccess-AI-Collective/xpos-rope
add support to extend context with xpos rope
2023-06-10 14:04:03 -04:00
Angainor Development
79e2a6f140 Merge branch 'main' into patch-1 2023-06-10 19:07:54 +02:00
Angainor Development
c2508987a6 Remove explicit definition of cfg.inference 2023-06-10 19:06:10 +02:00
Wing Lian
215d775147 Merge pull request #180 from Glavin001/feat/stream-inference
Add streaming inference & fix stopping at EOS
2023-06-10 12:04:34 -04:00
Wing Lian
f36e227eaf formatting for linter 2023-06-10 12:00:52 -04:00
Wing Lian
5878bb1f3a add option to readme 2023-06-10 11:57:41 -04:00
Wing Lian
a03a7d7d8b add support to extend context with xpos rope 2023-06-10 10:29:46 -04:00
Glavin Wiechert
fec6bcc3e6 Add streaming inference & fix stopping at EOS 2023-06-10 08:14:47 +00:00
Wing Lian
931e606459 Merge pull request #179 from OpenAccess-AI-Collective/fix-max_seq_len
fix for max sequence len across different model types
2023-06-09 20:52:03 -04:00
Wing Lian
7f09106437 fix for max sequence len across different model types 2023-06-09 20:42:33 -04:00
NanoCode012
6b50200234 Merge pull request #178 from PocketDocLabs/main
Update README.md to reflect current gradient checkpointing support
2023-06-10 08:26:48 +09:00
PocketDocLabs
16f9e28048 Update README.md to reflect current gradient checkpointing support
Previously the readme stated gradient checkpointing was incompatible with 4-bit lora in the current implementation however this is no longer the case. I have replaced the warning with a link to the hugging face documentation on gradient checkpointing.
2023-06-09 16:10:58 -07:00
NanoCode012
b9083a7fc1 Merge pull request #176 from NanoCode012/fix/peft-import
Fix backward compat for peft
2023-06-10 07:56:35 +09:00
NanoCode012
aefb2fc681 Fix backward compat for peft 2023-06-10 07:46:36 +09:00
NanoCode012
b5aa8d854c Merge pull request #169 from NanoCode012/feat/landmark
Feat: Add landmark attention
2023-06-10 07:26:06 +09:00
NanoCode012
4d6490bce2 Merge pull request #171 from OpenAccess-AI-Collective/NanoCode012-falcon-lora-matrix
Fix falcon support lora
2023-06-09 17:58:22 +09:00
NanoCode012
b242b69e10 Fix falcon support lora 2023-06-09 17:50:16 +09:00
NanoCode012
320beb20f4 Merge pull request #170 from OpenAccess-AI-Collective/NanoCode012-lambdalabs-fix
Feat: Improve lambda labs instruction
2023-06-09 16:52:27 +09:00
Angainor Development
bd3b537344 Feed cfg.inference 2023-06-09 08:59:05 +02:00
Angainor Development
813cfa4c14 WIP: Rely on cfg.inference 2023-06-09 08:49:32 +02:00
NanoCode012
2e13ceff37 Improve lambda labs instruction 2023-06-09 15:03:08 +09:00
NanoCode012
2a801b001a Fix grad checkpoint and outputs param 2023-06-09 14:28:44 +09:00
NanoCode012
e44c9e0b3e Fix patching via import instead of hijacking 2023-06-09 14:27:24 +09:00
NanoCode012
55b8542de8 Feat: Add landmark attention 2023-06-09 12:54:08 +09:00
Wing Lian
febe902517 Merge pull request #168 from bratao/main
Disable Wandb if no wandb project is specified
2023-06-08 22:05:56 -04:00
Bruno Cabral
f4df266842 Disable Wandb 2023-06-08 21:02:02 -03:00
NanoCode012
281dc3df59 Merge pull request #167 from NanoCode012/fix/redundant-save-eval-steps
Fix: Refactor out unmodified save_steps and eval_steps
2023-06-09 01:39:33 +09:00
NanoCode012
2ef4634d45 Refactor out unmodified save_steps and eval_steps 2023-06-09 01:23:13 +09:00
NanoCode012
7eae90333e Merge pull request #166 from NanoCode012/fix/seed
Fix: Set to use cfg.seed or 42 for seed
2023-06-09 01:15:08 +09:00
NanoCode012
c8242de725 Merge pull request #132 from utensil/falcon-7b-qlora
Axolotl supports falcon + qlora
2023-06-09 01:14:03 +09:00
NanoCode012
2cfe9e9b16 Set to use cfg.seed or 42 for backward compat 2023-06-09 01:02:36 +09:00
Utensil
79a8f52181 Trim trailing whitespace 2023-06-08 23:48:57 +08:00
NanoCode012
afaa0d2c01 Merge pull request #164 from NanoCode012/fix/falcon-fsdp-validate
Fix: Validate falcon with fsdp
2023-06-09 00:44:12 +09:00
NanoCode012
bfd27ba55e Fix failing test 2023-06-09 00:35:03 +09:00
NanoCode012
babf0fdb71 Validate falcon with fsdp 2023-06-09 00:29:04 +09:00
Utensil
a52f4816b0 Default wandb_project to empty as suggested
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-08 23:04:19 +08:00
NanoCode012
81911d112c Merge pull request #163 from NanoCode012/feat/matmul-tf32
Feat: Set matmul tf32=True when tf32 passed
2023-06-09 00:01:31 +09:00
NanoCode012
52765ac588 Set matmul tf32 2023-06-08 23:41:12 +09:00
NanoCode012
73e9ea4069 Merge pull request #143 from NanoCode012/fix/deprecate-prepare-8bit-training
Fix future deprecate prepare_model_for_int8_training
2023-06-08 23:07:53 +09:00
NanoCode012
f8d379883d Merge pull request #162 from NanoCode012/fix/custom-prompt-readme
Fix: Move custom prompts out of hidden
2023-06-08 23:05:17 +09:00
NanoCode012
04a1b77307 Merge pull request #161 from NanoCode012/fix/peft-setup
Fix: Update peft and gptq instruction
2023-06-08 23:01:53 +09:00
NanoCode012
2097a09d2d Move custom prompts out of hidden 2023-06-08 22:53:56 +09:00
NanoCode012
cfff94b123 Add peft install for quickstart 2023-06-08 22:50:20 +09:00
NanoCode012
2b222de5b6 Update peft and gptq instruction 2023-06-08 22:48:26 +09:00
NanoCode012
df9528f865 Fix future deprecate prepare_model_for_int8_training 2023-06-08 21:42:10 +09:00
Angainor Development
193c73bce0 Fix training over existing lora
When training with Lora, and starting with an existing lora weights, current code produces a model with 0 trainable params and training can't work.
Adding the "is_trainable" param allows the loaded peft to be trained and fixes the bug.
2023-06-08 09:18:58 +02:00
Wing Lian
6abfd87d44 Merge pull request #158 from OpenAccess-AI-Collective/prompter-fixes
fix camel ai, add guanaco/oasst mapping for sharegpt
2023-06-07 11:02:30 -04:00
Wing Lian
59bb2197ed fix camel ai, add guanaco/oasst mapping for sharegpt 2023-06-07 09:51:29 -04:00
Wing Lian
9a02e7e1ff Merge pull request #155 from OpenAccess-AI-Collective/misc-fixes
new prompters, misc fixes for output dir missing using fsdp, and changing max seq len
2023-06-06 16:52:39 -04:00
Wing Lian
5b33e295bd update docs 2023-06-05 22:48:16 -04:00
Wing Lian
4ac9e251b7 new prompters, misc fixes for output dir missing using fsdp, and changing max seq len 2023-06-05 22:41:00 -04:00
Utensil
c9c050316f Default micro_batch_size to 1 for a safer start 2023-06-03 17:26:33 +08:00
Utensil
ca11ae9689 Add comments/alternatives for falcon-qlora configs 2023-06-03 15:04:02 +08:00
Wing Lian
328c3bce96 Merge pull request #149 from OpenAccess-AI-Collective/docker-clone-axolotl
clone in docker
2023-06-02 15:15:30 -04:00
Wing Lian
5cd2126439 shallow clone 2023-06-02 14:54:28 -04:00
Wing Lian
12620f3089 clone in docker 2023-06-02 14:52:50 -04:00
Wing Lian
4ab0c8b201 Merge pull request #148 from OpenAccess-AI-Collective/fix-device-load 2023-06-02 14:37:17 -04:00
Wing Lian
74ebbf4371 fix device map 2023-06-02 14:29:08 -04:00
Wing Lian
76a70fd739 Merge pull request #147 from OpenAccess-AI-Collective/winglian-rocker-images
Update README.md for correct image tags
2023-06-02 14:10:40 -04:00
Wing Lian
618816d4df Update README.md for correct image tags 2023-06-02 14:10:23 -04:00
Wing Lian
91992cb8f5 Merge pull request #146 from FarisHijazi/main
added docker-compose file
2023-06-02 13:58:23 -04:00
FarisHijazi
84169d15b3 added docker-compose file 2023-06-02 18:17:43 +03:00
Wing Lian
ecfe8d0a1a Merge pull request #142 from NanoCode012/feat/custom-prompt-readme
Feat: Add custom prompt readme and add missing prompt strategies to Readme
2023-06-02 07:21:04 -04:00
Wing Lian
eee44a3b47 Merge pull request #141 from NanoCode012/feat/lambdalabs-readme
Feat: Add lambdalabs instruction
2023-06-02 07:20:12 -04:00
NanoCode012
078a43eef8 Remove redundant instruction 2023-06-02 12:30:11 +09:00
NanoCode012
33e1890086 Add pygmalion 2023-06-02 12:27:51 +09:00
NanoCode012
1c38253692 Add other prompt_strategies 2023-06-02 12:24:44 +09:00
NanoCode012
496b83f778 Add short instruction for custom prompts 2023-06-02 12:16:20 +09:00
NanoCode012
ff68a95781 Add lambdalabs instruction 2023-06-02 12:09:40 +09:00
Utensil
fb3d40f197 falcon + qlora + xformer mbs 40 gas 2 on A6000 2023-06-01 18:29:20 +08:00
NanoCode012
288fd62431 Merge pull request #135 from NanoCode012/fix/grad-accu-readme
Fix: Update doc for grad_accu and add validation tests for batch size
2023-06-01 06:33:05 +09:00
NanoCode012
3c71c8debe Update doc for grad_accu and add validation tests for batch size 2023-06-01 06:13:47 +09:00
Wing Lian
a6f5e5eaec Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix
fix batch size calculation
2023-05-31 14:24:48 -04:00
Wing Lian
5a631b305b fix batch size calculation 2023-05-31 14:11:32 -04:00
Wing Lian
f94dd626f0 Merge pull request #130 from OpenAccess-AI-Collective/gas
swap batch size for gradient accumulation steps to decouple from num gpu
2023-05-31 13:03:51 -04:00
Wing Lian
5079753b7a Merge pull request #131 from OpenAccess-AI-Collective/fix-packing-mask
fix packing so that concatenated sequences reset the attention
2023-05-31 13:03:37 -04:00
Wing Lian
0136f510f2 don't worry about duplicate code here 2023-05-31 12:05:43 -04:00
Utensil
72bf8aafb6 Create config-7b-qlora.yml 2023-06-01 00:00:37 +08:00
Utensil
8afb0fbaba Axolotl supports falcon + qlora 2023-05-31 23:58:40 +08:00
Wing Lian
9b8585dc70 fix packing so that concatenated sequences reset the attention 2023-05-31 11:38:52 -04:00
Wing Lian
8eb5811d4e Merge pull request #129 from OpenAccess-AI-Collective/builder-badge
add badge info to readme
2023-05-31 10:37:59 -04:00
Wing Lian
e0011fdf55 Fix base builder, missing tags 2023-05-31 09:52:03 -04:00
Wing Lian
6e9e98720e Merge pull request #127 from OpenAccess-AI-Collective/py310-docker-runpod
add py310 support from base image
2023-05-31 09:39:42 -04:00
Wing Lian
c2a0792680 swap batch size for gradient accumulation steps to decouple from num gpu 2023-05-31 09:38:12 -04:00
Wing Lian
b267d24a2b add badge info to readme 2023-05-31 09:28:44 -04:00
Wing Lian
5c3f5db38b Add files via upload 2023-05-31 09:22:54 -04:00
Wing Lian
e3d03745ba add py310 support from base image 2023-05-31 09:07:28 -04:00
NanoCode012
fac46002d4 Merge pull request #119 from NanoCode012/feat/update-inference
Feat(inference): Swap to GenerationConfig
2023-05-31 14:09:18 +09:00
NanoCode012
33d40179ba Increase max_new_tokens
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-05-31 14:04:49 +09:00
Wing Lian
dcb03d6da4 Merge pull request #114 from OpenAccess-AI-Collective/accelerate-dep
Add accelerate dep
2023-05-31 00:47:17 -04:00
NanoCode012
0e4be625ae Merge pull request #118 from NanoCode012/feat/torch-readme
Fix(readme): Fix torch missing from readme
2023-05-31 13:29:41 +09:00
NanoCode012
bdc4bd7d4e Update README.md 2023-05-31 13:24:28 +09:00
Wing Lian
2d0ba3b818 Merge pull request #124 from OpenAccess-AI-Collective/xformers-fix
copy xformers attn from ooba since we removed dep on alpaca_lora_4bit
2023-05-31 00:11:40 -04:00
Wing Lian
c7021e191f Merge pull request #120 from OpenAccess-AI-Collective/model-from-path
split up llama model loading so config can be loaded from base config and models can be loaded from a path
2023-05-31 00:08:38 -04:00
Wing Lian
c56818b119 don't worry about dupes 2023-05-31 00:06:47 -04:00
Wing Lian
2675fb756e update readme for SDP 2023-05-31 00:04:54 -04:00
Wing Lian
1076bcbbca Update src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-05-31 00:00:19 -04:00
Wing Lian
2daa6835f0 Update src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-05-30 23:59:05 -04:00
Wing Lian
e3c494ca7b remove unused import and update readme 2023-05-30 23:55:45 -04:00
Wing Lian
ad0ea6aaab black formatting
ignore copied file
fix linting
2023-05-30 23:50:29 -04:00
Wing Lian
876edd83d0 Merge pull request #123 from OpenAccess-AI-Collective/bas-batch
add support for gradient accumulation steps
2023-05-30 23:45:29 -04:00
Wing Lian
6cb2310592 copy xformers attn from ooba since we removed dep on alpaca_lora_4bit 2023-05-30 23:34:36 -04:00
Wing Lian
6fa40bf8ad black formatting 2023-05-30 23:33:37 -04:00
Wing Lian
3aad5f3b3e add support for gradient accumulation steps 2023-05-30 23:24:37 -04:00
Wing Lian
39a208c2bc fix up tokenizer config, isort fix 2023-05-30 23:00:02 -04:00
Wing Lian
2520ecd6df split up llama model loading so config can be loaded from base config and models can be loaded from a path 2023-05-30 22:32:44 -04:00
Wing Lian
c5b0af1a7e define python version (3.10) explicitly as string in yaml 2023-05-30 22:23:35 -04:00
NanoCode012
988aeb9c34 Feat: Swap to GenerationConfig 2023-05-31 10:48:19 +09:00
NanoCode012
cf61f14bff FIx(readme): Fix torch missing from readme 2023-05-31 10:28:49 +09:00
Wing Lian
0abcd71a85 Merge pull request #115 from OpenAccess-AI-Collective/docker-version-fixes
docker fixes: py310, fix cuda arg in deepspeed
2023-05-30 18:11:26 -04:00
Wing Lian
c43c5c84ff py310, fix cuda arg in deepspeed 2023-05-30 18:02:34 -04:00
Wing Lian
36ec6e1a0e Add accelerate dep 2023-05-30 16:36:13 -04:00
Wing Lian
13b80937f9 add release draft template for gh
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-05-30 15:10:19 -04:00
Wing Lian
bbc5bc5791 Merge pull request #108 from OpenAccess-AI-Collective/docker-gptq
default to qlora support, make gptq specific image
2023-05-30 15:07:04 -04:00
Wing Lian
4df9da74e3 Merge pull request #105 from viktoriussuwandi/viktoriussuwandi-patch
Viktoriussuwandi patch
2023-05-30 15:05:23 -04:00
Wing Lian
2531ea24c1 Merge pull request #106 from fearnworks/qlora-openllama-3b-example
Qlora openllama 3b example
2023-05-30 15:05:05 -04:00
Wing Lian
01a75fd027 Merge pull request #98 from NanoCode012/feat/pre-commit
Add pre-commit: black+flake8+pylint+mypy+isort+bandit
2023-05-30 14:57:15 -04:00
NanoCode012
b81c97ff76 Fix pre-commit for rebased files 2023-05-31 03:01:38 +09:00
NanoCode012
594e72b6e8 Fix incorrect rebase 2023-05-31 02:58:50 +09:00
NanoCode012
25eeeeba0b Fix sharegpt prompt 2023-05-31 02:55:21 +09:00
Wing Lian
cfcc549f6b fix relative path for fixtures 2023-05-31 02:55:21 +09:00
NanoCode012
a1f9850b91 Fix security issue or ignore false positives 2023-05-31 02:53:53 +09:00
NanoCode012
83d29209f7 Add bandit 2023-05-31 02:53:53 +09:00
NanoCode012
d011422200 Add isort 2023-05-31 02:53:53 +09:00
NanoCode012
b1cc54b14a Update pip install to also setup tests 2023-05-31 02:53:53 +09:00
NanoCode012
c17dae6d07 Update src/axolotl/prompt_strategies/alpaca_instruct.py
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-05-31 02:53:53 +09:00
NanoCode012
37293dce07 Apply isort then black 2023-05-31 02:53:53 +09:00
NanoCode012
96e8378692 Delete extract_lora.py 2023-05-31 02:53:53 +09:00
NanoCode012
e9650d3ae4 Fix mypy typing 2023-05-31 02:53:53 +09:00
NanoCode012
f1232b35ba Update mypy dependencies 2023-05-31 02:53:53 +09:00
NanoCode012
741a3f2edc Add mypy 2023-05-31 02:53:53 +09:00
NanoCode012
0dd35c74af Ignore unsupported-binary-operation 2023-05-31 02:53:53 +09:00
NanoCode012
db288e9b13 Set python version 2023-05-31 02:53:53 +09:00
NanoCode012
be22551435 Fix unsupported operand type(s) for | 2023-05-31 02:53:53 +09:00
NanoCode012
b832a0ac62 Black formatting 2023-05-31 02:53:53 +09:00
NanoCode012
afb31e13a3 Add badge and update contribution section 2023-05-31 02:53:53 +09:00
NanoCode012
1bf1f59a41 Move black to dev requirements 2023-05-31 02:53:53 +09:00
NanoCode012
8e46c0fb0d Refactor duplicate code between Prompter and Pygmalion 2023-05-31 02:53:53 +09:00
NanoCode012
1f3c3f5ea0 Lint validation 2023-05-31 02:53:53 +09:00
NanoCode012
0e952889dc Lint test_dict 2023-05-31 02:53:53 +09:00
NanoCode012
9c6750a075 Lint wandb 2023-05-31 02:53:53 +09:00
NanoCode012
c2dbf2c526 Lint validation 2023-05-31 02:53:53 +09:00
NanoCode012
e6b57decbd Lint tokenization 2023-05-31 02:53:53 +09:00
NanoCode012
fe1f4c4e7d Lint schedulers 2023-05-31 02:53:53 +09:00
NanoCode012
dae14e5951 Ignore too-many-instance-attributes 2023-05-31 02:53:53 +09:00
NanoCode012
633ff2150f Lint dict 2023-05-31 02:53:53 +09:00
NanoCode012
5d86137f70 Lint prompt_tokenizers 2023-05-31 02:53:53 +09:00
NanoCode012
01c8a333b3 Lint pygmalion 2023-05-31 02:53:53 +09:00
NanoCode012
7eb33a77dd Lint test_prompters 2023-05-31 02:53:53 +09:00
NanoCode012
1645a4ddd5 Lint creative_acr 2023-05-31 02:53:53 +09:00
NanoCode012
145b060cbe Lint alpaca_instruct 2023-05-31 02:53:53 +09:00
NanoCode012
8cc0aadcb8 Lint alpaca_chat 2023-05-31 02:53:53 +09:00
NanoCode012
6abb7f6a16 Lint datasets 2023-05-31 02:53:53 +09:00
NanoCode012
de2406c488 Lint convert.py 2023-05-31 02:53:53 +09:00
NanoCode012
8b617cc7f6 Lint setup.py 2023-05-31 02:53:53 +09:00
NanoCode012
ddb86ea821 Lint trainer.py 2023-05-31 02:53:53 +09:00
NanoCode012
1a2bd7ff62 Ignore too-few-public-methods 2023-05-31 02:53:23 +09:00
NanoCode012
82971e1565 Lint finetune.py 2023-05-31 02:53:23 +09:00
NanoCode012
f4e5d86268 Lint models.py 2023-05-31 02:53:23 +09:00
NanoCode012
daf47ccf45 Refactor disable pylint 2023-05-31 02:53:23 +09:00
NanoCode012
545cfeb5c7 Refactor error code to use full error message 2023-05-31 02:53:23 +09:00
NanoCode012
69722aeef4 Remove fixme disable 2023-05-31 02:53:23 +09:00
NanoCode012
5658717dbd Remove disable too many arg 2023-05-31 02:53:23 +09:00
NanoCode012
e8717d3bef Remove disable 2023-05-31 02:53:23 +09:00
NanoCode012
54c3b5b25f Ignore too-many-arguments 2023-05-31 02:53:23 +09:00
NanoCode012
5062eca069 Lint callbacks.py 2023-05-31 02:53:23 +09:00
NanoCode012
cb4f0e9342 Lint prompters.py 2023-05-31 02:53:23 +09:00
NanoCode012
4c0eddb3f8 Refactor 2023-05-31 02:53:23 +09:00
NanoCode012
1c60c10e00 Lint flash_attn.py 2023-05-31 02:53:23 +09:00
NanoCode012
903ea3080d Fix lint 2023-05-31 02:53:23 +09:00
NanoCode012
cb7cd3429f Fix data.py lint 2023-05-31 02:53:23 +09:00
NanoCode012
d57ba56746 Ignore import and too many * pylint errors 2023-05-31 02:53:23 +09:00
NanoCode012
c3a4697016 Update ignores 2023-05-31 02:53:22 +09:00
NanoCode012
392dfd9b07 Lint and format 2023-05-31 02:53:22 +09:00
NanoCode012
a98deb31a6 Add config files 2023-05-31 02:53:22 +09:00
NanoCode012
36596adaf7 Add pre-commit: black+flake8+pylint 2023-05-31 02:53:22 +09:00
jphillips
6cee881d64 Update examples/qlora-openllama-3b/README.md
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-05-30 09:33:33 -05:00
Wing Lian
48612f8376 cleanup from pr feedback 2023-05-30 09:56:30 -04:00
Wing Lian
d91a769b88 update docs 2023-05-29 20:37:32 -04:00
Wing Lian
6ef96f569b default to qlora support, make gptq specific image 2023-05-29 20:34:41 -04:00
jphillips
ac85c0ed36 Add Readme, Clean up comments 2023-05-29 14:35:58 -05:00
jphillips
f1fbf666f7 Merge branch 'main' of https://github.com/OpenAccess-AI-Collective/axolotl into qlora-openllama-3b-example 2023-05-29 09:09:43 -05:00
jphillips
370d057096 Add qlora-openllama-3b example 2023-05-29 09:07:46 -05:00
Viktorius Suwandi
e0ccaccce2 Update wandb_log_model on vicuna_13B_4bit_reflect.yml 2023-05-29 16:34:13 +07:00
Viktorius Suwandi
15e57ba6ee Update wandb_log_model on config.yml 2023-05-29 16:33:20 +07:00
Viktorius Suwandi
4eb68ac3f7 Update wandb_log_model on config-3b.yml 2023-05-29 16:32:49 +07:00
Viktorius Suwandi
b6a539b53c Update wandb_log_model on cerebras_1_3B_alpaca.yml 2023-05-29 16:32:20 +07:00
Viktorius Suwandi
abddcf4dfe Update wandb_log_model on pythia_1_2B_alpaca.yml 2023-05-29 16:31:53 +07:00
Viktorius Suwandi
15aabd2903 Update wandb_log_model on llama_7B_jeopardy.yml 2023-05-29 15:44:01 +07:00
Viktorius Suwandi
232b931081 Update wandb_log_model on llama_65B_alpaca.yml 2023-05-29 15:43:43 +07:00
Viktorius Suwandi
0736f4f9c1 Update wandb_log_model on llama_13B_alpaca.yml 2023-05-29 15:43:20 +07:00
Viktorius Suwandi
d77d736631 Update wandb_log_model on llama_7B_alpaca.yml 2023-05-29 15:43:01 +07:00
Viktorius Suwandi
fad06befee Update wandb_log_model on config.yml 2023-05-29 15:42:38 +07:00
Viktorius Suwandi
2aacf75ee1 Update wandb_log_model on galactica_1_3B.yml 2023-05-29 15:42:19 +07:00
Viktorius Suwandi
71871345a6 Update wandb_log_model on llama_7B_4bit.yml 2023-05-29 15:41:59 +07:00
Viktorius Suwandi
0d14e951a8 Update wandb_log_model on stability_3b.yml 2023-05-29 15:41:42 +07:00
Viktorius Suwandi
84fc217f79 Update wandb_log_model on gpt_neox_20b.yml 2023-05-29 15:41:24 +07:00
Viktorius Suwandi
f317296259 Update wandb_log_model on quickstart.yml 2023-05-29 15:40:58 +07:00
Viktorius Suwandi
42a971df32 Update wandb_log_model on sample.yml 2023-05-29 15:39:42 +07:00
89 changed files with 5083 additions and 1218 deletions

3
.bandit Normal file
View File

@@ -0,0 +1,3 @@
[bandit]
exclude = tests
skips = B101

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

31
.github/release-drafter.yml vendored Normal file
View File

@@ -0,0 +1,31 @@
name-template: 'v$RESOLVED_VERSION'
tag-template: 'v$RESOLVED_VERSION'
categories:
- title: '🚀 Features'
labels:
- 'feature'
- 'enhancement'
- title: '🐛 Bug Fixes'
labels:
- 'fix'
- 'bugfix'
- 'bug'
- title: '🧰 Maintenance'
label: 'chore'
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.
version-resolver:
major:
labels:
- 'major'
minor:
labels:
- 'minor'
patch:
labels:
- 'patch'
default: patch
template: |
## Whats Changed
$CHANGES

View File

@@ -12,16 +12,29 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: self-hosted
strategy:
fail-fast: false
matrix:
include:
- cuda: cu118
- cuda: "118"
cuda_version: 11.8.0
cuda_version_bnb: "118"
python_version: "3.9"
pytorch: 2.0.0
- cuda: cu117
cuda_version: 11.7.0
cuda_version_bnb: "117"
axolotl_extras:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.0
axolotl_extras:
- cuda: "117"
cuda_version: 11.7.1
python_version: "3.9"
pytorch: 1.13.1
axolotl_extras:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras: gptq
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -43,12 +56,13 @@ jobs:
context: .
file: ./docker/Dockerfile-base
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDA_VERSION_BNB=${{ matrix.cuda_version_bnb }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}

View File

@@ -11,14 +11,29 @@ jobs:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras:
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.0
axolotl_extras:
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.0
cuda_version: 11.7.1
python_version: "3.9"
pytorch: 1.13.1
axolotl_extras:
runs-on: self-hosted
steps:
- name: Checkout
@@ -40,10 +55,10 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
@@ -56,10 +71,24 @@ jobs:
include:
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras:
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.0
axolotl_extras:
- cuda: cu118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.0
cuda_version: 11.7.1
python_version: "3.9"
pytorch: 1.13.1
axolotl_extras:
runs-on: self-hosted
steps:
- name: Checkout
@@ -81,10 +110,10 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
file: ./docker/Dockerfile-runpod
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

16
.github/workflows/pre-commit.yml vendored Normal file
View File

@@ -0,0 +1,16 @@
name: pre-commit
on:
pull_request:
push:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0

View File

@@ -7,6 +7,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.9", "3.10"]
timeout-minutes: 10

2
.gitignore vendored
View File

@@ -160,4 +160,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.idea/

2
.isort.cfg Normal file
View File

@@ -0,0 +1,2 @@
[settings]
profile=black

39
.mypy.ini Normal file
View File

@@ -0,0 +1,39 @@
[mypy]
exclude = venv
[mypy-alpaca_lora_4bit.*]
ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True
[mypy-flash_attn.*]
ignore_missing_imports = True
[mypy-huggingface_hub]
ignore_missing_imports = True
[mypy-transformers.*]
ignore_missing_imports = True
[mypy-peft]
ignore_missing_imports = True
[mypy-bitsandbytes]
ignore_missing_imports = True
[mypy-datasets]
ignore_missing_imports = True
[mypy-fire]
ignore_missing_imports = True
[mypy-setuptools]
ignore_missing_imports = True
[mypy-addict]
ignore_missing_imports = True
[mypy-xformers.*]
ignore_missing_imports = True

42
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,42 @@
default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v2.17.4
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:
- id: mypy
additional_dependencies:
[
'types-PyYAML',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5
hooks:
- id: bandit
args: [
'--ini',
'.bandit',
]

14
.pylintrc Normal file
View File

@@ -0,0 +1,14 @@
[MASTER]
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,

View File

@@ -2,3 +2,6 @@
- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.

267
README.md
View File

@@ -9,36 +9,40 @@
<p>
Go ahead and axolotl questions!!
</p>
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
</div>
</div>
## Axolotl supports
| | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
|---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | |
| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
| falcon | ✅ | | ✅ | ❌ | | ❌ | ❌ | ✅ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
## Quickstart ⚡
**Requirements**: Python 3.9.
**Requirements**: Python 3.9 and Pytorch 2.0.
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
pip3 install -e .[int4]
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
accelerate config
# finetune lora
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
# inference
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
--inference --lora_model_dir="./lora-out"
```
@@ -48,18 +52,83 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
- Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
```
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
- `winglian/axolotl:dev`: dev branch (not usually up to date)
Or run on the current files for development:
```sh
docker compose up -d
```
- `winglian/axolotl:dev`: dev branch
- `winglian/axolotl-runpod:main`: for runpod
- Conda/Pip venv
1. Install python **3.9**
2. Install python dependencies with ONE of the following:
- `pip3 install -e .[int4]` (recommended)
- `pip3 install -e .[int4_triton]`
- `pip3 install -e .`
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install python dependencies with ONE of the following:
- Recommended, supports QLoRA, NO gptq/int4 support
```bash
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
```
- gptq/int4 support, NO QLoRA
```bash
pip3 install -e .[gptq]
```
- same as above but not recommended
```bash
pip3 install -e .[gptq_triton]
```
- LambdaLabs
<details>
<summary>Click to Expand</summary>
1. Install python
```bash
sudo apt update
sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
python -V # should be 3.9
```
2. Install pip
```bash
wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
```
3. Install torch
```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
```
4. Axolotl
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install -e . # change depend on needs
pip3 install protobuf==3.20.3
pip3 install -U requests
pip3 install -U --ignore-installed psutil
pip3 install -U scipy
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
```
5. Set path
```bash
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
```
</details>
### Dataset
@@ -69,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations
- `sharegpt:chat`: conversations
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -110,13 +179,70 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"article": "...", "summary": "..."}
```
> Have some new format to propose? Check if it's already defined in [data.py](src/axolotl/utils/data.py) in `dev` branch!
- `alpaca_chat`: basic instruct for alpaca chat
```json
{"instruction": "...", "input": "...", "response": "..."}
```
- `alpaca_chat.load_qa`: question and answer for alpaca chat
```json
{"question": "...", "answer": "..."}
```
- `alpaca_chat.load_concise`: question and answer for alpaca chat, for concise answers
```json
{"instruction": "...", "input": "...", "response": "..."}
```
- `alpaca_chat.load_camel_ai`: question and answer for alpaca chat, for load_camel_ai
```json
{"message_1": "...", "message_2": "..."}
```
- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
```json
{"system_prompt": "...", "question": "...", "response": "..."}
```
- `context_qa`: in context question answering from an article
```json
{"article": "...", "question": "...", "answer": "..."}
```
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
```json
{"article": "...", "unanswerable_question": "..."}
```
- `creative_acr.load_answer`: instruction and revision
```json
{"instruction": "...", "revision": "..."}
```
- `creative_acr.load_critique`: critique
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "..."}
```
- `creative_acr.load_revise`: critique and revise
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
```
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```json
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
</details>
#### How to add custom prompts
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
Optionally, download some datasets, see [data/README.md](data/README.md)
### Config
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
@@ -129,10 +255,18 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
- dataset
```yaml
sequence_len: 2048 # max token length for prompt
# huggingface repo
datasets:
- path: vicgalle/alpaca-gpt4 # local or huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
# local
datasets:
- path: json
data_files: data.jsonl # or json
type: alpaca # format from earlier
sequence_len: 2048 # max token length / prompt
```
- loading
@@ -142,6 +276,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
bf16: true # require >=ampere
fp16: true
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
Note: Repo does not do 4-bit quantization.
@@ -169,12 +305,19 @@ base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# you can specify to choose a specific model revision from huggingface hub
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
model_type: AutoModelForCausalLM
# Corresponding tokenizer for the model AutoTokenizer is a good choice
tokenizer_type: AutoTokenizer
# Trust remote code for untrusted source
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -195,10 +338,10 @@ tf32: true # require >=ampere
# a list of one or more datasets to finetune the model with
datasets:
# this can be either a hf dataset, or relative path
# hf dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format OR format:prompt_style (chat/instruct)
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
data_files: # path to source data files
shards: # number of shards to split data into
@@ -207,6 +350,8 @@ datasets:
dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
@@ -258,20 +403,25 @@ wandb_log_model: # 'checkpoint'
output_dir: ./completed-model
# training hyperparameters
batch_size: 8
gradient_accumulation_steps: 1
micro_batch_size: 2
eval_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
logging_steps:
save_steps:
eval_steps:
# save model as safetensors (require safetensors package)
save_safetensors:
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false
# does not work with current implementation of 4-bit LoRA
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row
@@ -293,11 +443,27 @@ log_sweep_max_lr:
optimizer:
# specify weight decay
weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_epsilon:
# Gradient clipping max norm
max_grad_norm:
# whether to bettertransformers
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Landmark attention (only llama)
landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only
xpos_rope:
# resume from a specific checkpoint dir
resume_from_checkpoint:
@@ -365,11 +531,16 @@ Pass the appropriate flag to the train command:
- Pretrained LORA:
```bash
--inference --lora_model_dir ./completed-model
--inference --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
--inference --base_model ./completed-model
--inference --base_model="./completed-model"
```
- Full weights finetune w/ a prompt from a text file:
```bash
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
```
### Merge LORA to base
@@ -380,6 +551,12 @@ Add below flag to train command above
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
```
If you run out of CUDA memory, you can try to merge in system RAM with
```bash
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
```
## Common Errors 🧰
> Cuda out of memory
@@ -387,6 +564,7 @@ Add below flag to train command above
Please reduce any below
- `micro_batch_size`
- `eval_batch_size`
- `gradient_accumulation_steps`
- `sequence_len`
> RuntimeError: expected scalar type Float but found Half
@@ -397,12 +575,41 @@ Try set `fp16: true`
Try to turn off xformers.
## Need help? 🙋♂️
## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
```
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
## Community Showcase
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
PocketDoc Labs
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
## Contributing 🤝
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **greatly welcome**!
Please run below to setup env
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
```

View File

@@ -1,15 +0,0 @@
compute_environment: LOCAL_MACHINE
distributed_type: 'NO'
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -1,40 +0,0 @@
base_model: cerebras/Cerebras-GPT-1.3B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
type: alpaca
- path: data/vicuna_cleaned.jsonl
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- c_attn
lora_fan_in_fan_out: false
wandb_project: pythia-1.4b-lora
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 32
micro_batch_size: 4
num_epochs: 5
learning_rate: 0.0003
train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -1,41 +0,0 @@
base_model: facebook/galactica-1.3b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
adapter:
lora_model_dir:
sequence_len: 1024
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-llama-alpaca
batch_size: 32
micro_batch_size: 16
num_epochs: 3
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: false
tf32: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
tokens:
pad_token: "[PAD]"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,39 +0,0 @@
base_model: EleutherAI/gpt-neox-20b
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
- path: nomic-ai/gpt4all-j-prompt-generations
type: alpaca
shards: 4
shards_index: 0
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- query_key_value
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: gpt4all-neox-20b
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./gpt4all-neox-20b
batch_size: 48
micro_batch_size: 4
num_epochs: 5
learning_rate: 0.00003
lr_scheduler: one_cycle
train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -1,39 +0,0 @@
base_model: huggyllama/llama-13b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
type: sharegpt
dataset_prepared_path: last_run_prepared
val_set_size: 0.002
adapter:
lora_model_dir:
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./llama-13b-sharegpt
batch_size: 64
micro_batch_size: 2
warmup_steps: 1000
save_steps:
eval_steps:
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience: 5
resume_from_checkpoint:
local_rank:

View File

@@ -1,44 +0,0 @@
base_model: huggyllama/llama-65b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
type: alpaca
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-65b-lora
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-llama-alpaca
batch_size: 128
micro_batch_size: 16
warmup_steps: 1000
save_steps:
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -1,45 +0,0 @@
base_model: decapoda-research/llama-7b-hf-int4
base_model_config: decapoda-research/llama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: tatsu-lab/alpaca # original alpaca dataset
type: alpaca
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-test
batch_size: 8
micro_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
load_4bit: true
xformers_attention: true
flash_attention:

View File

@@ -1,41 +0,0 @@
base_model: huggyllama/llama-7b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
type: alpaca
- path: data/vicuna_cleaned.jsonl
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-llama-alpaca
batch_size: 128
micro_batch_size: 16
num_epochs: 5
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -1,45 +0,0 @@
base_model: decapoda-research/llama-7b-hf-int4
base_model_config: decapoda-research/llama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
datasets:
- path: tatsu-lab/alpaca # original alpaca dataset
type: alpaca
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 1024
max_packed_sequence_len: 1024
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-test
batch_size: 4
micro_batch_size: 1
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
gptq: true
xformers_attention: true
flash_attention:

View File

@@ -1,86 +0,0 @@
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# this can also be a relative path to a model on disk
base_model: decapoda-research/llama-7b-hf-int4
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model
base_model_config: decapoda-research/llama-7b-hf
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
model_type: AutoModelForCausalLM
# Corresponding tokenizer for the model AutoTokenizer is a good choice
tokenizer_type: AutoTokenizer
# whether you are training a 4-bit quantized model
load_4bit: true
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
# a list of one or more datasets to finetune the model with
datasets:
# this can be either a hf dataset, or relative path
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca
# axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
val_set_size: 0.04
# if you want to use lora, leave blank to train all parameters in original model
adapter: lora
# if you already have a lora model trained that you want to load, put that here
lora_model_dir:
# the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
max_packed_sequence_len: 1024
# lora hyperparameters
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
# wandb configuration if your're using it
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
# where to save the finsihed model to
output_dir: ./completed-model
# training hyperparameters
batch_size: 8
micro_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false
# Use CUDA bf16
bf16: true
# Use CUDA tf32
tf32: true
# does not work with current implementation of 4-bit LoRA
gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
# specify a scheduler to use with the optimizer. only one_cycle is supported currently
lr_scheduler:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
flash_attention:
# resume from a specific checkpoint dir
resume_from_checkpoint:
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
# be careful with this being turned on between different models
auto_resume_from_checkpoints: false
# don't mess with this, it's here for accelerate and torchrun
local_rank:

View File

@@ -1,56 +0,0 @@
base_model: stabilityai/stablelm-base-alpha-3b
base_model_config: stabilityai/stablelm-base-alpha-3b
load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter:
lora_model_dir:
sequence_len: 4096
max_packed_sequence_len: 4096
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: stable-alpaca-3b
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./stable-alpaca-3b
batch_size: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100
eval_steps: 50
save_steps: 200
debug:
deepspeed:
weight_decay: 0.01
fsdp:
fsdp_config:
#tokens:
# pad_token: "[PAD]"
# bos_token: "<s>"
# eos_token: "</s>"
# unk_token: "<unk>"

View File

@@ -1,45 +0,0 @@
base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_4bit: true
gptq_groupsize: 128
gptq_model_v1: false
datasets:
# https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
- path: data/alpaca_reflect_pruned.jsonl
type: reflection
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
# - o_proj
lora_fan_in_fan_out: false
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-reflect
batch_size: 8
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: false
early_stopping_patience: 3
resume_from_checkpoint:
local_rank:
flash_attention: true

View File

@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
## Convert the JSON data files to JSONL.
```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
---

20
docker-compose.yaml Normal file
View File

@@ -0,0 +1,20 @@
# version: '3.8'
services:
axolotl:
build:
context: .
dockerfile: ./docker/Dockerfile
volumes:
- .:/workspace/axolotl
- ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables
environment:
- WANDB_API_KEY=${WANDB_API_KEY}
deploy:
resources:
reservations:
devices:
- driver: nvidia
# count: 1
capabilities: [gpu]
command: tail -f /dev/null

View File

@@ -2,19 +2,25 @@ ARG BASE_TAG=main-base
FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
RUN apt-get update && \
apt-get install -y vim curl
WORKDIR /workspace
# The base image ships with `pydantic==1.8.2` which is not working
RUN python3 -m pip install -U --no-cache-dir pydantic
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"transformers @ git+https://github.com/huggingface/transformers.git@main"
RUN mkdir axolotl
COPY . axolotl/
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
pip install -e .[int4]
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[$AXOLOTL_EXTRAS]; \
else \
pip install -e .; \
fi
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -9,7 +9,7 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9"
ARG PYTORCH="2.0.0"
ARG CUDA="cu118"
ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS flash-attn-builder
@@ -52,6 +52,8 @@ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
FROM base-builder AS deepspeed-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
WORKDIR /workspace
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
@@ -61,12 +63,12 @@ RUN git clone https://github.com/microsoft/DeepSpeed.git && \
FROM base-builder AS bnb-builder
WORKDIR /workspace
ARG CUDA_VERSION_BNB="118"
ENV CUDA_VERSION_BNB=$CUDA_VERSION_BNB
ARG CUDA="118"
ENV CUDA=$CUDA
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
cd bitsandbytes && \
CUDA_VERSION=$CUDA_VERSION_BNB make cuda11x && \
CUDA_VERSION=$CUDA make cuda11x && \
python setup.py bdist_wheel
FROM base-builder
@@ -75,7 +77,7 @@ FROM base-builder
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
@@ -93,10 +95,6 @@ COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/d
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"transformers @ git+https://github.com/huggingface/transformers.git@main" && \
pip3 install awscli && \
RUN pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic
pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -0,0 +1,60 @@
base_model: cerebras/Cerebras-GPT-1.3B
base_model_config: cerebras/Cerebras-GPT-1.3B
load_in_8bit: false
load_in_4bit: true
strict: false
push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- c_fc
- c_attn
- c_proj
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
batch_size: 4
micro_batch_size: 4
num_epochs: 2
optimizer: paged_adamw_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"

View File

@@ -23,7 +23,7 @@ lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project: falcon-7b
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -61,4 +61,3 @@ special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"

View File

@@ -0,0 +1,92 @@
# 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
- path: QingyiSi/Alpaca-CoT
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 64b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 3
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 64b
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 5
save_steps: 10
debug:
deepspeed:
weight_decay: 0.000001
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"

View File

@@ -23,7 +23,7 @@ lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project: falcon-7b
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -61,4 +61,3 @@ special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"

57
examples/gptj/qlora.yml Normal file
View File

@@ -0,0 +1,57 @@
base_model: EleutherAI/gpt-j-6b
base_model_config: EleutherAI/gpt-j-6b
load_in_8bit: false
load_in_4bit: true
strict: false
push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 2
optimizer: paged_adamw_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"

View File

@@ -3,6 +3,6 @@
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
```shell
accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
```

View File

@@ -24,9 +24,9 @@ lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
wandb_log_model:
output_dir: ./llama-7b-lora-int4
batch_size: 1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit

View File

@@ -7,30 +7,28 @@ datasets:
- path: openaccess-ai-collective/jeopardy
type: jeopardy
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
sequence_len: 512
max_packed_sequence_len:
lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: jeopardy-bot-7b
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
wandb_log_model:
output_dir: ./jeopardy-bot-7b
batch_size: 4
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
@@ -48,11 +46,10 @@ eval_steps: 110
save_steps: 660
debug:
deepspeed:
weight_decay: 0.0001
weight_decay: 0.1
fsdp:
fsdp_config:
tokens:
pad_token: "[PAD]"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -22,9 +22,9 @@ lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
wandb_log_model:
output_dir: ./mpt-alpaca-7b
batch_size: 1
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit

View File

@@ -0,0 +1,16 @@
# openllama-3b
Basic full tune
```shell
accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
```
LoRA
```shell
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
```
QLoRA
```shell
accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
```

View File

@@ -0,0 +1,62 @@
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 256
max_packed_sequence_len:
lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./openllama-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
float16: true
bf16: false
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,5 +1,5 @@
base_model: openlm-research/open_llama_3b_600bt_preview
base_model_config: openlm-research/open_llama_3b_600bt_preview
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
@@ -49,7 +49,7 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:

View File

@@ -0,0 +1,61 @@
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
batch_size: 4
micro_batch_size: 4
num_epochs: 2
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,9 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️

View File

@@ -0,0 +1,49 @@
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -1,36 +1,29 @@
base_model: EleutherAI/pythia-1.4b-deduped
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true
datasets:
- path: data/alpaca_data_gpt4.jsonl
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
- path: data/vicuna_cleaned.jsonl
type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
lora_model_dir:
sequence_len: 2048
lora_r: 8
sequence_len: 512
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- query_key_value
# - xxx
lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: pythia-1.4b-lora
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 48
wandb_log_model:
output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 5
num_epochs: 3
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
@@ -39,3 +32,6 @@ tf32: True
early_stopping_patience:
resume_from_checkpoint:
local_rank:
weight_decay: 0.1
eval_steps: 20
logging_steps: 1

View File

@@ -1,7 +1,7 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM
tokenizer_type: GPTNeoXTokenizer
tokenizer_type: AutoTokenizer
trust_remote_code:
load_in_8bit: false
datasets:
@@ -23,7 +23,7 @@ lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_watch:
wandb_run_id:
wandb_log_model: checkpoint
wandb_log_model:
output_dir: ./redpajama-alpaca-3b
batch_size: 4
micro_batch_size: 1

BIN
image/axolotl-badge-web.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

3
requirements-dev.txt Normal file
View File

@@ -0,0 +1,3 @@
pre-commit
black
mypy

View File

@@ -4,16 +4,17 @@ bitsandbytes>=0.39.0
addict
fire
PyYAML==6.0
black
datasets
accelerate>=0.19.0
sentencepiece
wandb
einops
xformers
optimum
# qlora things
bert-score==0.3.13
evaluate==0.4.0
rouge-score==0.1.2
scipy
scikit-learn==1.2.2
numba

View File

@@ -1,24 +1,38 @@
"""Module to convert json file to jsonl"""
import os
import sys
from pathlib import Path
from typing import Optional, Union
import fire
from typing import Optional
from axolotl.convert import (
FileReader,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
StdoutWriter,
)
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
from axolotl.convert import *
def main(
input: Path,
file: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
"""
Convert a json file to jsonl
"""
file_reader = FileReader()
writer: Union[StdoutWriter, FileWriter]
if to_stdout or output is None:
writer = StdoutWriter()
else:
@@ -28,7 +42,7 @@ def main(
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
converter.convert(input, output)
converter.convert(file, output)
if __name__ == "__main__":

View File

@@ -1,3 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
@@ -5,25 +7,28 @@ import random
import signal
import sys
from pathlib import Path
from typing import Optional, List, Dict, Any, Union
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -31,68 +36,101 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def choose_device(cfg):
def get_device():
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
else:
try:
if torch.backends.mps.is_available():
return "mps"
except:
return "cpu"
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device == "cuda":
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"})
tokenizer.add_special_tokens({"bos_token": "<s>"})
tokenizer.add_special_tokens({"eos_token": "</s>"})
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
# gc = GenerationConfig() # TODO swap out and use this
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
do_sample=True,
use_cache=True,
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=100,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def choose_config(path: Path):
yaml_files = [file for file in path.glob("*.yml")]
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
@@ -130,31 +168,37 @@ def train(
config = choose_config(config)
# load the config from the yaml file
with open(config, "r") as f:
cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k in kwargs:
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False:
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = (
cfg.gradient_accumulation_steps // cfg.world_size
)
cfg.batch_size = cfg.batch_size * cfg.world_size
setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
@@ -163,26 +207,37 @@ def train(
cfg.fp16 = True
cfg.bf16 = False
validate_config(cfg)
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first
logging.info("loading tokenizer...")
tokenizer = load_tokenizer(
cfg.base_model_config,
cfg.tokenizer_type,
cfg
)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
)
@@ -200,7 +255,6 @@ def train(
tokenizer,
cfg,
adapter=cfg.adapter,
inference=("inference" in kwargs),
)
if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -213,9 +267,15 @@ def train(
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if "inference" in kwargs:
if cfg.inference:
logging.info("calling do_inference function")
do_inference(cfg, model, tokenizer)
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
prompter = None
else:
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
return
if "shard" in kwargs:
@@ -237,9 +297,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
signal.SIGINT,
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
logging.info("Starting trainer...")
@@ -252,20 +318,33 @@ def train(
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints, key=lambda path: int(path.split("-")[-1])
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
logging.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -1,43 +0,0 @@
#!/bin/bash
export WANDB_MODE=offline
export WANDB_CACHE_DIR=/workspace/data/wandb-cache
mkdir -p $WANDB_CACHE_DIR
mkdir -p /workspace/data/huggingface-cache/{hub,datasets}
export HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
export HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
export TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
export NCCL_P2P_DISABLE=1
nvidia-smi
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
gpu_indices=$(seq 0 $((num_gpus - 1)) | paste -sd "," -)
export CUDA_VISIBLE_DEVICES=$gpu_indices
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
apt-get update
apt-get install -y build-essential ninja-build vim git-lfs
git lfs install
pip3 install --force-reinstall https://download.pytorch.org/whl/nightly/cu117/torch-2.0.0.dev20230301%2Bcu117-cp38-cp38-linux_x86_64.whl --index-url https://download.pytorch.org/whl/nightly/cu117
if [ -z "${TORCH_CUDA_ARCH_LIST}" ]; then # only set this if not set yet
# this covers most common GPUs that the installed version of pytorch supports
# python -c "import torch; print(torch.cuda.get_arch_list())"
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
fi
# install flash-attn and deepspeed from pre-built wheels for this specific container b/c these take forever to install
mkdir -p /workspace/wheels
cd /workspace/wheels
curl -L -O https://github.com/OpenAccess-AI-Collective/axolotl/raw/wheels/wheels/deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
curl -L -O https://github.com/OpenAccess-AI-Collective/axolotl/raw/wheels/wheels/flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
pip install deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
pip install flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
pip install "peft @ git+https://github.com/huggingface/peft.git@main" --force-reinstall --no-dependencies
cd /workspace/
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
cd axolotl
pip install -e .[int4]
mkdir -p ~/.cache/huggingface/accelerate/
cp configs/accelerate/default_config.yaml ~/.cache/huggingface/accelerate/default_config.yaml

View File

@@ -1,7 +1,9 @@
from setuptools import setup, find_packages
"""setup.py for axolotl"""
from setuptools import find_packages, setup
install_requires = []
with open("./requirements.txt", "r") as requirements_file:
with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
@@ -17,10 +19,10 @@ setup(
packages=find_packages(),
install_requires=install_requires,
extras_require={
"int4": [
"gptq": [
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"int4_triton": [
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"extras": [

View File

@@ -1,47 +1,76 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json
import sys
class FileReader:
"""
Reads a file and returns its contents as a string
"""
def read(self, file_path):
with open(file_path, "r") as file:
with open(file_path, encoding="utf-8") as file:
return file.read()
class FileWriter:
"""
Writes a string to a file
"""
def __init__(self, file_path):
self.file_path = file_path
def write(self, content):
with open(self.file_path, "w") as file:
with open(self.file_path, "w", encoding="utf-8") as file:
file.write(content)
class StdoutWriter:
"""
Writes a string to stdout
"""
def write(self, content):
sys.stdout.write(content)
sys.stdout.write("\n")
class JsonParser:
"""
Parses a string as JSON and returns the result
"""
def parse(self, content):
return json.loads(content)
class JsonlSerializer:
"""
Serializes a list of JSON objects into a JSONL string
"""
def serialize(self, data):
lines = [json.dumps(item) for item in data]
return "\n".join(lines)
class JsonToJsonlConverter:
"""
Converts a JSON file to JSONL
"""
def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
self.file_reader = file_reader
self.file_writer = file_writer
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(self, input_file_path, output_file_path):
def convert(
self, input_file_path, output_file_path
): # pylint: disable=unused-argument
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -1,10 +1,12 @@
"""Module containing Dataset functionality"""
import logging
from typing import List
import torch
from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
@@ -14,7 +16,14 @@ from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
class TokenizedPromptDataset(IterableDataset):
def __init__(
"""
Iterable dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
"""
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
@@ -24,12 +33,16 @@ class TokenizedPromptDataset(IterableDataset):
def __iter__(self):
iterator = iter(self.dataset)
count = 0
# Loop through the entire dataset
for example in iterator:
try:
yield self.prompt_tokenizer.tokenize_prompt(example)
count += 1
except InvalidDataException:
pass
if count == 0:
raise RuntimeError("Expected at least one datapoint in dataset.")
# TODO this isn't the best since it can't interleave datasets
@@ -42,7 +55,7 @@ class ConstantLengthDataset(IterableDataset):
seq_length (int): Length of token sequences to return.
"""
def __init__(
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
@@ -82,10 +95,8 @@ class ConstantLengthDataset(IterableDataset):
else:
example_len = 0
if (
not example_len
or buffer_len + int(add_concat_token) + example_len
> self.seq_length
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +106,8 @@ class ConstantLengthDataset(IterableDataset):
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if (
labels.size() == input_ids.size()
and attention_mask.size() == input_ids.size()
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
@@ -108,15 +118,25 @@ class ConstantLengthDataset(IterableDataset):
logging.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
buffer_len = 0
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token:
input_ids.append(self.concat_token_id)

View File

@@ -1,17 +1,15 @@
"""Flash attention monkey patch for llama model"""
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import torch
from torch import nn
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
@@ -27,6 +25,7 @@ def forward(
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
@@ -74,7 +73,11 @@ def forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -82,35 +85,56 @@ def forward(
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,
None,
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@@ -0,0 +1,233 @@
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import math
from typing import Optional, Tuple
import torch
import transformers.models.llama.modeling_llama
from torch import nn
try:
import xformers.ops
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
attn_weights = None
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
# pylint: skip-file
"""
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
"""
import torch
import transformers
import transformers.models.llama.modeling_llama
from einops import rearrange
class XposRotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scale_base=2048,
use_xpos=True,
):
super().__init__()
self.max_seq_len_cached = max_position_embeddings
self.scale_base = scale_base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
freqs = torch.einsum("i , j -> i j", t, inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("freqs_cached", freqs, persistent=False)
if not use_xpos:
self.register_buffer("scale", None)
self.register_buffer("scale_cached", torch.ones(1))
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
scale_cached = scale ** rearrange(power, "n -> n 1")
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
self.register_buffer("scale", scale, persistent=False)
self.register_buffer("scale_cached", scale_cached, persistent=False)
def forward(
self,
x,
seq_len,
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
self.inv_freq
)
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
self.register_buffer("freqs_cached", freqs)
if self.scale is None:
self.register_buffer(
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
)
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, "n -> n 1")
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
self.register_buffer("scale_cached", scale)
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
freqs = freqs[position_ids, :]
if scale.shape[-1] != 1:
scale = scale[position_ids, :]
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
return q_embed, k_embed
def replace_llama_rope_with_xpos_rope():
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb

View File

@@ -1,3 +1,5 @@
"""Module to load prompt strategies."""
import importlib
@@ -7,8 +9,8 @@ def load(strategy, tokenizer, cfg):
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
fn = getattr(m, load_fn)
return fn(tokenizer, cfg)
except:
pass
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
func = getattr(mod, load_fn)
return func(tokenizer, cfg)
except Exception: # pylint: disable=broad-exception-caught
return None

View File

@@ -1,21 +1,65 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Tuple
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.chat.value),
AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class AlpacaConcisePrompter(AlpacaPrompter):
"""
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
"""
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
class AlpacaChatPrompter(AlpacaPrompter):
"""
Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
"""
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
def __init__(self): # pylint: disable=super-init-not-called
self.prompt_style = PromptStyle.CHAT.value
self.match_prompt_style()
class NoSystemPrompter(AlpacaPrompter):
"""
Null Prompter with no system prompts
"""
system_prompt = ""
system_no_input_prompt = ""
turn_format = "{instruction} {input} "
turn_no_input_format = "{instruction} "
def __init__(self): # pylint: disable=super-init-not-called
pass
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for AlpacaQA
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
"",
@@ -23,9 +67,49 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
)
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.chat.value),
class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for CamelAI datasets
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["message_1"],
"",
prompt["message_2"],
)
def load_concise(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
AlpacaConcisePrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
AlpacaChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_camel_ai(tokenizer, cfg):
return CamelAIPromptTokenizingStrategy(
AlpacaChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_no_prompt(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
UnpromptedPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -1,10 +1,21 @@
"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.instruct),
AlpacaPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_no_prompt(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
UnpromptedPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -0,0 +1,120 @@
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
prompt["system"],
)
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
(
instruction,
input, # pylint: disable=redefined-builtin
response,
system,
) = self.parse_instruction_fields(prompt)
user_prompt = next(
iter(
self.prompter.build_prompt_w_system(
system,
instruction,
input,
)
)
)
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_prompt
class SystemDataPrompter(AlpacaPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset
"""
def build_prompt_w_system(
self,
system: str,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = system + self.turn_format.format(instruction=instruction, input=input)
else:
res = system + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
"""
Tokenizing strategy for OpenOrca datasets
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
return (
prompt["question"],
"",
prompt["response"],
prompt["system_prompt"],
)
def load(tokenizer, cfg):
return load_chat(tokenizer, cfg)
def load_instruct(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_chat(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -0,0 +1,67 @@
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
# article, unanswerable_question, question, answer
def load_404(tokenizer, cfg):
return AlpacaMissingInfoContextPromptTokenizingStrategy(
AlpacaContextPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load(tokenizer, cfg):
return AlpacaContextPromptTokenizingStrategy(
AlpacaContextPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class AlpacaContextPrompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
"""
system_prompt = (
"Use the following contextual information to concisely answer the question.\n"
)
system_no_input_prompt = (
"Use the following contextual information to concisely answer the question.\n"
)
class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenization Strategy to combine in-context article with a question and answer
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["article"] + "\n===\n" + prompt["question"],
"",
prompt["answer"],
)
class AlpacaMissingInfoContextPromptTokenizingStrategy(
InstructionPromptTokenizingStrategy
):
"""
Tokenization Strategy to combine in-context article with a question that can't be answered
from the context and a default response to that effect
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
"",
"The context provided does not contain any information about your inquiry. "
"Therefore, I'm unable to answer your question based on the given context.",
)

View File

@@ -1,11 +1,18 @@
from typing import Union, Generator
"""Module loading the CreativePromptTokenizingStrategy and similar classes"""
from typing import Generator, Tuple, Union
import yaml
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for Creative Answering
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
question = prompt["instruction"]
answer = prompt[
"revision"
@@ -18,6 +25,10 @@ class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrat
class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for Creative Critique
"""
user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
@@ -49,12 +60,16 @@ Question: {question}
Answer: {answer}
"""
def parse_instruction_fields(self, prompt) -> (str, str, str):
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump(
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
prompt["scores"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
prompt["critiques"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
evaluation = scores + critiques
question = prompt["instruction"]
@@ -67,6 +82,10 @@ Answer: {answer}
class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for Creative Revise
"""
user_prompt = """Definitions:
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
@@ -81,12 +100,16 @@ Evaluation:
{evaluation}
"""
def parse_instruction_fields(self, prompt) -> (str, str, str):
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump(
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
prompt["scores"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
prompt["critiques"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
evaluation = scores + critiques
question = prompt["instruction"]
@@ -101,13 +124,19 @@ Evaluation:
class CreativePrompterBase:
"""
Base class for Creative Prompters
"""
system_prompt = ""
prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None,
input: Union[ # pylint: disable=redefined-builtin, unused-argument
None, str
] = None,
output: Union[None, str] = None,
) -> Generator[str, None, None]:
if self.system_prompt:
@@ -120,30 +149,51 @@ class CreativePrompterBase:
class CreativeAnswerPrompter(CreativePrompterBase):
"""
Prompter for Creative Answering
"""
system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
class CreativeCritiquePrompter(CreativePrompterBase):
"""
Prompter for Creative Critique
"""
system_prompt = ""
class CreativeRevisePrompter(CreativePrompterBase):
"""
Prompter for Creative Revise
"""
system_prompt = ""
def load_answer(tokenizer, cfg):
return CreativeAnsweringPromptTokenizingStrategy(
CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
CreativeAnswerPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_critique(tokenizer, cfg):
return CreativeCritiquePromptTokenizingStrategy(
CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
CreativeCritiquePrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_revise(tokenizer, cfg):
return CreativeRevisePromptTokenizingStrategy(
CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
CreativeRevisePrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -1,29 +1,34 @@
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
import copy
import logging
from collections import defaultdict
from typing import Generator
from typing import Generator, List, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
PromptTokenizingStrategy,
parse_tokenized_to_result,
tokenize_prompt_default,
)
IGNORE_TOKEN_ID = -100
class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
bot_prefix_token_ids = []
"""
Tokenizing strategy for Pygmalion.
"""
bot_prefix_token_ids: List[int] = []
def __init__(self, prompter, tokenizer, *args, **kwargs):
super().__init__(prompter, tokenizer)
super().__init__(prompter, tokenizer, *args, **kwargs)
res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
self.bot_prefix_token_ids = res["input_ids"]
def tokenize_prompt(self, prompt):
result = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
current_len = 0
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
result, current_len = tokenize_prompt_default()
for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
role, message = part
if role == "system":
prefix = "<|system|>"
@@ -61,45 +66,29 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
logging.warning(f"unknown role in conversation: {role}")
res = defaultdict(lambda: [])
input_ids = res["input_ids"]
input_len = len(input_ids)
result["input_ids"][current_len : current_len + input_len] = input_ids
result["attention_mask"][current_len : current_len + input_len] = [
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
]
result["labels"][current_len : current_len + input_len] = labels
current_len += input_len
return result
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
class PygmalionPrompter:
"""
Prompter for Pygmalion.
"""
def __init__(self, *args, **kwargs):
pass
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
def build_prompt(
self, source, *args, **kwargs # pylint: disable=unused-argument
) -> Generator[Tuple[str, str], None, None]:
for msg in source:
yield msg["role"], msg["value"]

View File

@@ -0,0 +1,28 @@
"""Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
Tokenization strategy for asking bot to tell a joke and then explain why its funny
"""
# title, text, explanation
def get_conversation_thread(self, prompt):
title = "" if not prompt["title"] else prompt["title"] + " "
return [
{"from": "human", "value": "Tell me a joke."},
{"from": "gpt", "value": title + prompt["text"]},
{"from": "human", "value": "Why is that joke funny?"},
{"from": "gpt", "value": prompt["explanation"]},
]

View File

@@ -0,0 +1,67 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg):
return SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
]
return turns

View File

@@ -1,24 +1,33 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""
import abc
import copy
import functools
import logging
from typing import Dict, List, Tuple, Union
from transformers import PreTrainedTokenizer
from axolotl.prompters import IGNORE_TOKEN_ID
IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
class InvalidDataException(Exception):
pass
"""
Exception raised when the data is invalid
"""
class PromptTokenizingStrategy(abc.ABC):
"""
Abstract class for tokenizing strategies
"""
def __init__(
self,
prompter,
@@ -35,59 +44,21 @@ class PromptTokenizingStrategy(abc.ABC):
def tokenize_prompt(self, prompt):
pass
@functools.cache
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
return False
@functools.cache
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
return False
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
raise NotImplementedError
def tokenize_prompt(self, prompt):
instruction, input, response = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, response)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = next(
iter(
self.prompter.build_prompt(
instruction,
input,
)
)
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, response):
return next(
iter(
self.prompter.build_prompt(
instruction,
input,
response,
)
)
)
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
@@ -111,8 +82,64 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
return result
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(
self, prompt
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
raise NotImplementedError
def tokenize_prompt(self, prompt):
(
instruction,
input, # pylint: disable=redefined-builtin
response,
) = self.parse_instruction_fields(prompt)
user_prompt = next(
iter(
self.prompter.build_prompt(
instruction,
input,
)
)
)
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_prompt
def _build_full_prompt(
self, instruction, input, response # pylint: disable=redefined-builtin
):
return next(
iter(
self.prompter.build_prompt(
instruction,
input,
response,
)
)
)
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for Alpaca prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -121,7 +148,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for Alpaca Multiple Choice prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
"\n".join(f'- "{choice}"' for choice in prompt["choices"]),
@@ -130,7 +161,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for Jeopardy prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
prompt["category"],
@@ -139,7 +174,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for OpenAssistant prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["INSTRUCTION"],
"",
@@ -148,7 +187,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for SummarizeTLDR prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["article"],
"",
@@ -157,7 +200,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for GPTeacher prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -166,7 +213,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for NomicGPT4All prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["prompt"],
"",
@@ -175,28 +226,34 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> str:
return prompt["text"]
"""
Tokenizing strategy for Completion prompts.
"""
def tokenize_prompt(self, prompt):
instruction = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, None, None)
full_prompt = self._build_full_prompt(prompt["text"], None, None)
tokenized_full_prompt = self._tokenize(full_prompt)
return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, response):
return next(iter(self.prompter.build_prompt(instruction)))
def _build_full_prompt(
self, instruction, input, response
): # pylint: disable=redefined-builtin
return next(iter(self.prompter.build_prompt(instruction, input, response)))
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
"""
Tokenizing strategy for Reflection prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
raise NotImplementedError
def tokenize_prompt(self, prompt):
(
instruction,
input,
input, # pylint: disable=redefined-builtin
output,
reflection,
corrected,
@@ -223,7 +280,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, output, reflection, corrected):
def _build_full_prompt(
self, instruction, input, output, reflection, corrected
): # pylint: disable=redefined-builtin
return next(
iter(
self.prompter.build_prompt(
@@ -236,7 +295,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
)
)
def _tokenize(self, prompt, add_eos_token=True):
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
@@ -257,7 +316,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
"""
Tokenizing strategy for Alpaca Reflection prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -268,20 +331,19 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for ShareGPT prompts.
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
result = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
current_len = 0
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
try:
for i, part in enumerate(
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
@@ -289,7 +351,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=True
part.strip(),
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
@@ -300,32 +364,39 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistent response, should end with an eos token
res = self._tokenize(
part.strip(), add_eos_token=True, strip_bos_token=True
part.strip(),
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [assistant_token, *res["input_ids"]]
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "SYSTEM:":
part = part[1] # Ignore the system role from preamble
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
logging.warning("unhandled role: " + part[0])
else:
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
input_ids = res["input_ids"]
input_len = len(input_ids)
result["input_ids"][current_len : current_len + input_len] = input_ids
result["attention_mask"][current_len : current_len + input_len] = [
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
]
result["labels"][current_len : current_len + input_len] = labels
current_len += input_len
logging.warning(f"unhandled role: {part[0]}")
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as e:
raise InvalidDataException(str(e))
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
@@ -349,3 +420,40 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
result["labels"] = result["input_ids"].copy()
return result
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
"""
Returns the default values for the tokenize prompt function
"""
result: Dict[str, List[int]] = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
current_len = 0
return result, current_len
def parse_tokenized_to_result(
result: Dict[str, List[int]],
current_len: int,
res: Dict[str, List[int]],
labels: List[int],
pad_token_id: Union[int, None] = None,
) -> Tuple[Dict[str, List[int]], int]:
"""
Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
"""
input_ids = res["input_ids"]
input_len = len(input_ids)
result["input_ids"][current_len : current_len + input_len] = input_ids
result["attention_mask"][current_len : current_len + input_len] = [
1 if x != pad_token_id else 0 for x in input_ids
]
result["labels"][current_len : current_len + input_len] = labels
current_len += input_len
return result, current_len

View File

@@ -1,110 +1,155 @@
import copy
"""Module containing prompters"""
import dataclasses
import logging
from enum import auto, Enum
from typing import List, Tuple, Any, Union, Generator
from enum import Enum, auto
from typing import Generator, List, Optional, Tuple, Union
IGNORE_TOKEN_ID = -100
class PromptStyle(Enum):
instruct = "instruct"
chat = "chat"
"""
Enum for prompt styles
"""
INSTRUCT = "instruct"
CHAT = "chat"
class AlpacaPrompter:
"""
Base class for alpaca prompters
"""
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
prompt_style = None
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.instruct.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
self.match_prompt_style()
def match_prompt_style(self):
if self.prompt_style == PromptStyle.instruct.value:
self.prompt_input = (
self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
if self.prompt_style == PromptStyle.chat.value:
self.prompt_input = (
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.prompt_input.format(instruction=instruction, input=input)
res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else:
res = self.prompt_no_input.format(instruction=instruction)
res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class UnpromptedPrompter(AlpacaPrompter):
"""
Prompter for alpaca no system prompt
"""
system_prompt = ""
system_no_input_prompt = ""
class JeopardyPrompter(AlpacaPrompter):
"""
Prompter for Jeopardy
"""
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
class MultipleChoiceExplainPrompter(AlpacaPrompter):
"""
Prompter for multiple choice explain
"""
system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning."
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
system_no_input_prompt = (
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
class MultipleChoiceConcisePrompter(AlpacaPrompter):
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
"""
Prompter for multiple choice concise
"""
system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
def match_prompt_style(self):
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter):
prompt_no_input = (
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
)
"""
Prompter for summarize TLDR
"""
system_prompt = ""
system_no_input_prompt = ""
def match_prompt_style(self):
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter:
"""
Prompter for completion
"""
def build_prompt(
self, instruction: str, input=None, output=None
self,
instruction: str,
input=None, # pylint: disable=redefined-builtin, unused-argument
output=None, # pylint: disable=unused-argument
) -> Generator[str, None, None]:
yield instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter):
...
"""
Prompter for GPTeacher
"""
class NomicGPT4AllPrompter(AlpacaPrompter):
...
"""
Prompter for NomicGPT4All
"""
class ReflectAlpacaPrompter:
"""
Prompter for ReflectAlpaca
"""
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
@@ -120,7 +165,7 @@ class ReflectAlpacaPrompter:
self.match_prompt_style()
def match_prompt_style(self):
if self.prompt_style == PromptStyle.instruct.value:
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = (
self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
@@ -131,7 +176,7 @@ class ReflectAlpacaPrompter:
)
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
self.response_split = "### Final Response:"
if self.prompt_style == PromptStyle.chat.value:
if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = (
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
)
@@ -146,7 +191,7 @@ class ReflectAlpacaPrompter:
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
@@ -159,14 +204,13 @@ class ReflectAlpacaPrompter:
res = self.prompt_no_input.format(instruction=instruction)
if output and reflection and corrected:
label = self.agent_label.format(
output=output, reflection=reflection, corrected=corrected
output=output,
reflection=reflection,
corrected=corrected,
)
res = f"{res}{label}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SeparatorStyle(Enum):
"""Different separator style."""
@@ -187,18 +231,18 @@ class Conversation:
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
sep2: Optional[str] = None
def get_prompt(self) -> Generator[str, None, None]:
seps = [self.sep, self.sep2]
preamble = self.system + seps[0]
yield preamble
for i, (role, message) in enumerate(self.messages):
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
else:
logging.warning("role with empty message: " + role)
yield (role + ":",)
logging.warning(f"role with empty message: {role}")
yield (role + ":", "")
def copy(self):
return Conversation(
@@ -215,32 +259,35 @@ class Conversation:
self.messages.append([role, message])
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for the ShareGPT
"""
class ShareGPTPrompter:
def __init__(self, prompt_style=None):
if prompt_style != PromptStyle.chat.value:
raise Exception(
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
if prompt_style != PromptStyle.CHAT.value:
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
system: str = (
system_prompt
if system_prompt
else (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
)
self._conversation = Conversation(
system=system,
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
)
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
if source[0]["from"] == "system":
source.pop(0)
@@ -250,7 +297,7 @@ class ShareGPTPrompter:
# also happens on the data splitting leaving empty conversations
raise IndexError
conv = conv_vicuna_v1_1.copy()
conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
try:
@@ -261,9 +308,9 @@ class ShareGPTPrompter:
):
# Skip the first one if it is not from human
source = source[1:]
except IndexError as e:
except IndexError as err:
# sometimes there is a bing or system chat
raise e
raise err
conv.messages = []
for j, sentence in enumerate(source):

View File

@@ -1,16 +1,20 @@
"""Callbacks for Trainer class"""
import os
from optimum.bettertransformer import BetterTransformer
from transformers import (
Seq2SeqTrainer,
TrainerCallback,
TrainingArguments,
TrainerState,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
class SavePeftModelCallback(TrainerCallback):
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter"""
def on_save(
self,
args: TrainingArguments,
@@ -19,10 +23,47 @@ class SavePeftModelCallback(TrainerCallback):
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(peft_model_path)
return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods
"""Callback to save the BetterTransformer wrapped model"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control

View File

@@ -1,42 +1,38 @@
"""Module containing data utilities"""
import functools
import logging
from hashlib import md5
from pathlib import Path
from typing import Union
from typing import List, Tuple, Union
from datasets import (
load_from_disk,
load_dataset,
IterableDataset,
Dataset,
concatenate_datasets,
DatasetDict,
)
import torch
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
ShareGPTPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
CompletionPromptTokenizingStrategy,
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
CompletionPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
CompletionPrompter,
GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter,
ShareGPTPrompter,
JeopardyPrompter,
CompletionPrompter,
MultipleChoiceExplainPrompter,
SummarizeTLDRPrompter,
MultipleChoiceConcisePrompter,
)
@@ -45,11 +41,13 @@ def load_tokenized_prepared_datasets(
) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
md5(
md5( # nosec
(
str(cfg.sequence_len)
+ "@"
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
)
+ "|"
+ tokenizer_name
).encode("utf-8")
@@ -65,10 +63,11 @@ def load_tokenized_prepared_datasets(
try:
if cfg.push_dataset_to_hub:
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token,
)
dataset = dataset["train"]
except:
except Exception: # pylint: disable=broad-except # nosec
pass
if dataset:
@@ -80,44 +79,80 @@ def load_tokenized_prepared_datasets(
else:
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...")
if cfg.seed:
seed = cfg.seed
else:
logging.info("No seed provided, using default seed of 42")
seed = 42
datasets = []
# pylint: disable=invalid-name
for d in cfg.datasets:
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
try:
load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
load_dataset(
d.path,
streaming=True,
use_auth_token=use_auth_token,
)
ds_from_hub = True
except FileNotFoundError:
pass
# prefer local dataset, even if hub exists
if Path(d.path).exists():
ds: Dataset = load_dataset(
"json", data_files=d.path, streaming=False, split=None
)
local_path = Path(d.path)
if local_path.exists():
if local_path.is_dir():
ds = load_dataset(
d.path,
data_files=d.data_files,
streaming=False,
split=None,
)
elif local_path.is_file():
ds = load_dataset(
"json",
data_files=d.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
if d.data_files:
ds: Dataset = load_dataset(
ds = load_dataset(
d.path,
streaming=False,
data_files=d.data_files,
use_auth_token=use_auth_token,
)
else:
ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=use_auth_token)
ds = load_dataset(
d.path,
streaming=False,
use_auth_token=use_auth_token,
)
else:
fp = hf_hub_download(
repo_id=d.path, repo_type="dataset", filename=d.data_files
repo_id=d.path,
repo_type="dataset",
filename=d.data_files,
)
ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
if not ds:
raise Exception("unhandled dataset load")
raise ValueError("unhandled dataset load")
# support for using a subset of the data
if d.shards:
if "train" in ds:
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=d.shards, index=0
)
else:
ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
d_type = d.type
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
@@ -218,13 +253,21 @@ def load_tokenized_prepared_datasets(
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
suffix = ""
if ":load_" in d.type:
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
logging.error(
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
)
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
)
logging.info("tokenizing, merging, and shuffling master dataset")
samples = []
samples: List[int] = []
for d in datasets:
samples = samples + [i for i in d]
dataset = Dataset.from_list(samples).shuffle(seed=42)
samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0:
logging.info(
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
@@ -242,8 +285,10 @@ def load_tokenized_prepared_datasets(
def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
) -> (Dataset, Dataset):
tokenizer: PreTrainedTokenizerBase,
cfg,
default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset]:
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
@@ -256,13 +301,15 @@ def load_prepare_datasets(
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str(
md5(
md5( # nosec
(
str(cfg.sequence_len)
+ "@"
+ str(max_packed_sequence_len)
+ seed
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
)
+ "|"
+ tokenizer_name
).encode("utf-8")
@@ -282,10 +329,11 @@ def load_prepare_datasets(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token,
)
dataset = dataset["train"]
except:
except Exception: # pylint: disable=broad-except # nosec
pass
if dataset:
@@ -319,7 +367,7 @@ def load_prepare_datasets(
logging.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
)
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
dataset = Dataset.from_list(list(constant_len_dataset))
# filter out bad data
dataset = Dataset.from_list(
@@ -343,7 +391,8 @@ def load_prepare_datasets(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
f"{cfg.push_dataset_to_hub}/{ds_hash}",
private=True,
)
else:
dataset = load_tokenized_prepared_datasets(
@@ -355,11 +404,131 @@ def load_prepare_datasets(
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
)
dataset = dataset.shard(
num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
num_shards=cfg.dataset_shard_num,
index=cfg.dataset_shard_idx,
)
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
if cfg.val_set_size:
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples):
res = tokenizer(
examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
input_ids[i] = torch.cat(
(
input_ids[i],
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
logging.debug(len(ret["input_ids"]))
return ret
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
return dataset

View File

@@ -1,3 +1,5 @@
"""Module containing the DictDefault class"""
from addict import Dict

View File

@@ -1,52 +1,53 @@
"""Module for models and model loading"""
import logging
import math
import os
from pathlib import Path
from typing import Optional, Tuple, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from transformers import (
from optimum.bettertransformer import BetterTransformer
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
AutoConfig,
BitsAndBytesConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
try:
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
)
except:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
if TYPE_CHECKING:
from peft import PeftModel, PeftConfig
from axolotl.utils.dict import DictDefault
from transformers import PreTrainedTokenizer
from peft import PeftConfig # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(
base_model_config,
tokenizer_config,
tokenizer_type,
cfg,
):
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -54,7 +55,10 @@ def load_tokenizer(
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
if tokenizer.__class__.__name__ in [
"LlamaTokenizer",
"LlamaTokenizerFast",
]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -62,8 +66,8 @@ def load_tokenizer(
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: val})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
@@ -71,39 +75,62 @@ def load_tokenizer(
def load_model(
base_model,
base_model_config,
model_type,
tokenizer,
cfg,
adapter="lora",
inference=False,
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
is_llama_derived_model = "llama" in base_model or (
cfg.is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
)
if is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False:
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention:
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
logging.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention,
)
if cfg.bf16:
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
MEM_TOKEN,
patch_llama_with_landmark_attn,
)
logging.info("patching with landmark attention")
patch_llama_with_landmark_attn()
# Note: This might overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
)
logging.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope()
if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16:
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
@@ -114,12 +141,21 @@ def load_model(
)
replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training
except Exception as e:
logging.exception(e)
raise e
except Exception as err:
logging.exception(err)
raise err
try:
from peft import prepare_model_for_kbit_training
except ImportError:
# For backward compatibility
from peft import (
prepare_model_for_int8_training as prepare_model_for_kbit_training,
)
model_kwargs = {}
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
@@ -130,7 +166,7 @@ def load_model(
bnb_4bit_quant_type="nf4",
)
try:
if cfg.gptq and is_llama_derived_model:
if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
@@ -155,7 +191,7 @@ def load_model(
"unable to find a cached model file, this will likely fail..."
)
model_path = str(cache_model_path)
except:
except Exception: # pylint: disable=broad-exception-caught
model_path = cfg.base_model
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
@@ -168,9 +204,13 @@ def load_model(
else True,
)
load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
@@ -203,21 +243,37 @@ def load_model(
# device=cfg.device,
# )
# model.train() # sets to train instead of eval mode
elif model_type:
elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=True if cfg.trust_remote_code is True else False,
trust_remote_code=cfg.trust_remote_code or False,
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(config, "max_seq_len")
and config.max_seq_len
and cfg.sequence_len > config.max_seq_len
):
config.max_seq_len = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
@@ -225,20 +281,21 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as e:
except Exception as err: # pylint: disable=broad-exception-caught
logging.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
)
logging.exception(e)
logging.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=True if cfg.trust_remote_code is True else False,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -246,12 +303,23 @@ def load_model(
model.resize_token_embeddings(embeddings_len)
if (
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
and not cfg.gptq
and (load_in_8bit or cfg.load_in_4bit)
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
):
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)
logging.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
model, lora_config = load_adapter(model, cfg, adapter)
@@ -261,14 +329,14 @@ def load_model(
if cfg.gptq:
# Scales to half
logging.info("Fitting 4bit scales and zeros to half")
for n, m in model.named_modules():
if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
type(m)
for _, module in model.named_modules():
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
type(module)
):
if hasattr(m, "is_v1_model") and m.is_v1_model:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
if hasattr(module, "is_v1_model") and module.is_v1_model:
module.zeros = module.zeros.half()
module.scales = module.scales.half()
module.bias = module.bias.half()
if (
torch.cuda.device_count() > 1
@@ -278,8 +346,8 @@ def load_model(
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
setattr(model, 'is_parallelizable', True)
setattr(model, 'model_parallel', True)
setattr(model, "is_parallelizable", True)
setattr(model, "model_parallel", True)
requires_grad = []
for name, param in model.named_parameters(recurse=True):
@@ -289,6 +357,9 @@ def load_model(
logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
# TODO resume_from_checkpoint handling
return model, lora_config
@@ -308,11 +379,7 @@ def load_adapter(model, cfg, adapter):
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import (
AdaptionPromptConfig,
get_peft_model,
PeftModel,
)
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
@@ -325,7 +392,6 @@ def load_llama_adapter(model, cfg):
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
device_map=cfg.device_map,
torch_dtype=torch.float16,
)
else:
@@ -357,11 +423,7 @@ def find_all_linear_names(bits, model):
def load_lora(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import (
LoraConfig,
get_peft_model,
PeftModel,
)
from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])
@@ -391,8 +453,7 @@ def load_lora(model, cfg):
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
device_map=cfg.device_map,
# torch_dtype=torch.float16,
is_trainable=not cfg.inference,
)
else:
model = get_peft_model(model, lora_config)

View File

@@ -0,0 +1,173 @@
# pylint: skip-file
from typing import Any, List, Optional
import numba
import numpy as np
import torch.distributed as dist
from torch.utils.data import Sampler
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[int] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
m = (left + right) // 2
if ffd_check(lengths[start_index : start_index + m], c, n):
left = m
else:
right = m
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
class MultipackDistributedBatchSampler(Sampler):
"""Unpadded length sampling using Multipack.
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
"""
def __init__(
self,
batch_max_length: int,
lengths: List[int],
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
):
# Get rank
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.num_replicas = num_replicas
self.rank = rank
self.seed = seed
self.batch_max_length = batch_max_length
self.lengths = lengths
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
def set_epoch(self, epoch: int):
self.epoch = epoch
def generate_batches(self, set_stats=False):
indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(
len(self.lengths)
)
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
c=self.batch_max_length,
n=self.num_replicas,
)
batches = [indices[batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches
def __iter__(self):
batches = self.generate_batches(set_stats=True)
return iter(batches)
def num_batches(self):
batches = self.generate_batches()
return len(batches)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots

View File

@@ -1,7 +1,16 @@
from torch.optim.lr_scheduler import LRScheduler
"""Module for custom LRScheduler class"""
import math
from functools import partial
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class InterpolatingLogScheduler(LRScheduler):
"""
A scheduler that interpolates learning rates in a logarithmic fashion
"""
def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
"""A scheduler that interpolates learning rates in a logarithmic fashion
@@ -19,7 +28,9 @@ class InterpolatingLogScheduler(LRScheduler):
self.num_steps = num_steps
self.min_lr = min_lr
self.max_lr = max_lr
self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name
1 / (num_steps - 1)
)
super().__init__(optimizer, last_epoch)
def get_lr(self):
@@ -34,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
lrs = [self.max_lr for base_lr in self.base_lrs]
return lrs
def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float
):
if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
)
def get_cosine_schedule_with_quadratic_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_quadratic_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)

View File

@@ -1,6 +1,10 @@
from termcolor import colored
"""Module for tokenization utilities"""
import logging
from termcolor import colored
def check_dataset_labels(dataset, tokenizer):
# the dataset is already shuffled, so let's just check the first 5 elements
@@ -17,7 +21,7 @@ def check_example_labels(example, tokenizer):
# You can compare the input_ids and labels element-wise
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
colored_tokens = []
for i, (input_id, label_id, mask) in enumerate(
for _, (input_id, label_id, mask) in enumerate(
zip(input_ids, labels, attention_mask)
):
decoded_input_token = tokenizer.decode(input_id)
@@ -30,3 +34,5 @@ def check_example_labels(example, tokenizer):
logging.info(" ".join(colored_tokens))
logging.info("\n\n\n")
return " ".join(colored_tokens)

View File

@@ -1,28 +1,204 @@
"""Module containing the Trainer class and related functions"""
import importlib
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import bitsandbytes as bnb
import numpy as np
import torch.cuda
import transformers
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer
from torch.utils.data import Dataset
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.callbacks import SavePeftModelCallback
from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.sampler import MultipackDistributedBatchSampler
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
)
IGNORE_LABEL_ID = -100
class OneCycleLRSchedulerTrainer(Trainer):
def _find_multiple(val1, val2):
return (-(val1 // -val2)) * val2
def batch_to_tensor(batch, pad_id=0, dtype=torch.long, loss_dtype=torch.bfloat16):
# Pad an unused item to reach multiple of 64, for faster GEMM
pad_cur_len = sum(list(batch["length"]))
pad_len = _find_multiple(pad_cur_len, 64) - pad_cur_len
if pad_len > 0:
assert pad_len < 64
batch["input_ids"].append([pad_id] * pad_len)
batch["labels"].append([pad_id] * pad_len)
batch["attention_mask"].append([0] * pad_len)
batch["length"].append(pad_len)
# seqlen
batch_lengths = torch.tensor(list(batch["length"]), dtype=torch.int32, device="cpu")
max_seqlen = torch.max(batch_lengths)
cu_seqlens = torch.nn.functional.pad(
batch_lengths.cumsum(-1, dtype=torch.int32), (1, 0)
)
# nz elements
nz_num = cu_seqlens[-1]
nz_input_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
nz_position_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
nz_shifted_label_ids = torch.zeros(
(nz_num,), dtype=dtype, pin_memory=True, device="cpu"
)
nz_shifted_loss_weights = torch.zeros(
(nz_num,), dtype=loss_dtype, pin_memory=True, device="cpu"
)
index = 0
for token_list, length, labels_list in zip(
batch["input_ids"], batch["length"], batch["labels"]
):
tokens = torch.tensor(token_list, dtype=dtype, device="cpu")
position_ids = torch.arange(length, dtype=dtype, device="cpu")
# Input IDs & shifted labels
# shifted_label_ids = torch.where(masks, tokens, IGNORE_LABEL_ID)
shifted_label_ids = labels_list
shifted_label_ids = torch.nn.functional.pad(
shifted_label_ids[1:], (0, 1), "constant", IGNORE_LABEL_ID
)
nz_input_ids[index : index + length] = tokens
nz_position_ids[index : index + length] = position_ids
nz_shifted_label_ids[index : index + length] = shifted_label_ids
# Loss weights
mask_count = sum(1 for label in labels_list[1:] if label != IGNORE_LABEL_ID)
loss_weight = (
1 / mask_count if mask_count > 0 else 0
) # Avoid division by zero for paddings
nz_shifted_loss_weights[index : index + length] = loss_weight
index += length
# inputs
return {
"max_seqlen": max_seqlen,
"cu_seqlens": cu_seqlens,
"nz_input_ids": nz_input_ids,
"nz_position_ids": nz_position_ids,
"nz_shifted_label_ids": nz_shifted_label_ids,
"nz_shifted_loss_weights": nz_shifted_loss_weights,
}
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=True,
metadata={"help": "Use sample packing for efficient training."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
lengths = np.array([len(sample["input_ids"]) for sample in self.train_dataset])
return MultipackDistributedBatchSampler(
batch_max_length=self.args.per_device_train_batch_size
* self.args.max_seq_length,
lengths=lengths,
seed=self.args.seed,
)
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
lengths = np.array([len(sample["input_ids"]) for sample in eval_dataset])
return MultipackDistributedBatchSampler(
batch_max_length=self.args.per_device_eval_batch_size
* self.args.max_seq_length,
lengths=lengths,
seed=self.args.seed,
)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
num_training_steps = num_training_steps
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
@@ -50,19 +226,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
save_steps = cfg.save_steps
eval_steps = cfg.eval_steps
training_arguments_kwargs = {}
if cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.gradient_checkpointing is not None:
if cfg.seed:
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing:
if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
apply_gradient_checkpointing,
@@ -85,6 +263,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
# deepspeed
if (
os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -97,7 +278,25 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
# TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
training_args = transformers.TrainingArguments(
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
if cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps * cfg.num_epochs,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None
@@ -107,18 +306,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps" if save_steps else "epoch",
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
save_steps=save_steps,
save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=True
if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
and cfg.val_set_size > 0
and save_steps is not None
and save_steps % eval_steps == 0
and cfg.load_in_8bit is not True
else False,
load_best_model_at_end=(
cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.save_steps % cfg.eval_steps == 0
and cfg.load_in_8bit is not True
)
or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
@@ -140,7 +340,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if (
cfg.optimizer == "adamw_bnb_8bit"
and not cfg.gptq
and not "deepspeed" in training_arguments_kwargs
and "deepspeed" not in training_arguments_kwargs
and not cfg.fsdp
):
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
@@ -206,9 +406,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
)
callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
if cfg.local_rank == 0 and cfg.adapter in [
"lora",
"qlora",
]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True,
}
@@ -217,10 +423,30 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(model, tokenizer)
logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
batched=False,
num_proc=32,
)
trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
else transformers.Trainer
else AxolotlTrainer
)
trainer = trainer_cls(
model=model,

View File

@@ -1,7 +1,21 @@
"""Module for validating config files"""
import logging
import torch
def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
if cfg.batch_size:
logging.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.load_4bit:
raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
@@ -38,9 +52,59 @@ def validate_config(cfg):
)
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
if (
cfg.base_model and "mpt" in cfg.base_model.lower()
) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bloat16 is not True:
logging.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
if cfg.pretraining_dataset and cfg.group_by_length:
logging.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
if cfg.push_to_hub_model_id:
raise ValueError(
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adamw w bf16
# no 8bit adaAmw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -1,3 +1,5 @@
"""Module for wandb utilities"""
import os
@@ -13,3 +15,5 @@ def setup_wandb_env_vars(cfg):
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else:
os.environ["WANDB_DISABLED"] = "true"

12
tests/fixtures/alpaca/alpaca.json vendored Normal file
View File

@@ -0,0 +1,12 @@
[
{
"instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
"input": "Words: ['Hello', 'world'].",
"output": "['world', 'Hello']"
},
{
"instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
"input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
"output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
}
]

File diff suppressed because one or more lines are too long

View File

@@ -1,3 +1,6 @@
"""Module for testing DictDefault class"""
import unittest
import pytest
@@ -6,6 +9,10 @@ from axolotl.utils.dict import DictDefault
class DictDefaultTest(unittest.TestCase):
"""
Test DictDefault class
"""
def test_dict_default(self):
cfg = DictDefault(
{
@@ -41,7 +48,9 @@ class DictDefaultTest(unittest.TestCase):
}
)
cfg = cfg | DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
)
assert (
cfg.key_a.key_b == "value_b"
@@ -73,7 +82,7 @@ class DictDefaultTest(unittest.TestCase):
AttributeError,
match=r"'NoneType' object has no attribute 'another_random_key'",
):
cfg.random_key.another_random_key
cfg.random_key.another_random_key = "value"
def test_dict_shorthand_assignment(self):
"""

View File

@@ -0,0 +1,65 @@
"""Module for testing dataset sequence packing"""
import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
class TestPacking(unittest.TestCase):
"""
Test class for packing dataset sequences
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_resets_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
dateset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 0
if __name__ == "__main__":
unittest.main()

View File

@@ -1,3 +1,4 @@
"""Module for testing prompt tokenizers."""
import json
import logging
import unittest
@@ -5,14 +6,27 @@ from pathlib import Path
from transformers import AutoTokenizer
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompter
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
SystemDataPrompter,
)
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
logging.basicConfig(level="INFO")
class TestPromptTokenizationStrategies(unittest.TestCase):
"""
Test class for prompt tokenization strategies.
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
@@ -23,11 +37,15 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
)
def test_sharegpt_integration(self):
print(Path(__file__).parent)
with open(Path(__file__).parent / "fixtures/conversation.json", "r") as fin:
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
data = fin.read()
conversation = json.loads(data)
with open(Path(__file__).parent / "fixtures/conversation.tokenized.json", "r") as fin:
with open(
Path(__file__).parent / "fixtures/conversation.tokenized.json",
encoding="utf-8",
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompter("chat")
@@ -42,6 +60,79 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts
"""
prompter = NoSystemPrompter()
# pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
sample = {
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
"output": "world!",
}
example = strat.tokenize_prompt(sample)
world_idx = example["input_ids"].index(3186)
assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100
def test_alpaca(self):
"""
tests the interface between the user and assistant parts
"""
# pylint: disable=duplicate-code
prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
example = strat.tokenize_prompt(sample)
world_idx = example["input_ids"].index(6324)
assert example["labels"][world_idx] == 6324
assert example["labels"][world_idx - 1] == -100
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_system_alpaca(self):
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
strat = InstructionWSystemPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
sample = {
"system": "use cot",
"instruction": "hello!",
"output": "Hi! How can I help?",
}
example = strat.tokenize_prompt(sample)
assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
assert example["input_ids"][3] == 11889 # USER
if __name__ == "__main__":
unittest.main()

View File

@@ -1,9 +1,21 @@
"""Module testing prompters"""
import unittest
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
from axolotl.prompters import (
AlpacaPrompter,
MultipleChoiceExplainPrompter,
PromptStyle,
UnpromptedPrompter,
)
class AlpacaPrompterTest(unittest.TestCase):
"""
Test AlpacaPrompter
"""
def test_prompt_style_w_none(self):
prompter = AlpacaPrompter(prompt_style=None)
res = next(prompter.build_prompt("tell me a joke"))
@@ -11,8 +23,10 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Instruction:" in res
def test_prompt_style_w_instruct(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "Below is an instruction" in res
assert "### Instruction:" in res
assert "### Input:" in res
@@ -29,8 +43,10 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "ASSISTANT:" not in res
def test_prompt_style_w_chat(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "Below is an instruction" in res
assert "### Instruction:" not in res
assert "### Input:" not in res
@@ -46,4 +62,63 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "USER:" in res
assert "ASSISTANT:" in res
def test_system_prompt(self):
prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt_w_system(
"use cot", "tell me a joke about the following", "alpacas"
)
)
assert "use cot" in res
assert res.startswith("use cot")
assert "### Instruction:" not in res
assert "### Input:" not in res
assert "alpacas" in res
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res
class UnpromptedPrompterTest(unittest.TestCase):
"""
Test class for UnpromptedPrompter with no system prompts
"""
def test_prompt_style_w_none(self):
prompter = UnpromptedPrompter(prompt_style=None)
res = next(prompter.build_prompt("tell me a joke"))
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_instruct(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_chat(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "USER:" in res
assert "tell me a joke" in res
assert res.startswith("USER:")
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
"""
Test class for MultipleChoiceExplainPrompter
"""
def test_prompt_style_w_chat(self):
prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
assert "USER:" in res
assert "choose one" in res
assert "Choose the answer that best answers the question." in res
assert "- A\n- B\n- C" in res

31
tests/test_tokenizers.py Normal file
View File

@@ -0,0 +1,31 @@
"""
Test cases for the tokenizer loading
"""
import unittest
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers(unittest.TestCase):
"""
test class for the load_tokenizer fn
"""
def test_default_use_fast(self):
cfg = DictDefault({})
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self):
cfg = DictDefault(
{
"tokenizer_use_fast": False,
}
)
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" not in tokenizer.__class__.__name__
if __name__ == "__main__":
unittest.main()

View File

@@ -1,12 +1,26 @@
"""Module for testing the validation module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase):
"""
Test the validation module
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_load_4bit_deprecate(self):
cfg = DictDefault(
{
@@ -17,6 +31,17 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError):
validate_config(cfg)
def test_batch_size_unused_warning(self):
cfg = DictDefault(
{
"batch_size": 32,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert "batch_size is not recommended" in self._caplog.records[0].message
def test_qlora(self):
base_cfg = DictDefault(
{
@@ -24,7 +49,7 @@ class ValidationTest(unittest.TestCase):
}
)
cfg = base_cfg | DictDefault(
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
}
@@ -33,7 +58,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*8bit.*"):
validate_config(cfg)
cfg = base_cfg | DictDefault(
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
}
@@ -42,7 +67,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*gptq.*"):
validate_config(cfg)
cfg = base_cfg | DictDefault(
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": False,
}
@@ -51,7 +76,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*4bit.*"):
validate_config(cfg)
cfg = base_cfg | DictDefault(
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
}
@@ -67,7 +92,7 @@ class ValidationTest(unittest.TestCase):
}
)
cfg = base_cfg | DictDefault(
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
}
@@ -76,7 +101,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*8bit.*"):
validate_config(cfg)
cfg = base_cfg | DictDefault(
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
}
@@ -85,7 +110,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*gptq.*"):
validate_config(cfg)
cfg = base_cfg | DictDefault(
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
}
@@ -112,3 +137,179 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_gradient_accumulations_or_batch_size(self):
cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"batch_size": 1,
}
)
with pytest.raises(
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"batch_size": 1,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
}
)
validate_config(cfg)
def test_falcon_fsdp(self):
regex_exp = r".*FSDP is not supported for falcon models.*"
# Check for lower-case
cfg = DictDefault(
{
"base_model": "tiiuae/falcon-7b",
"fsdp": ["full_shard", "auto_wrap"],
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
# Check for upper-case
cfg = DictDefault(
{
"base_model": "Falcon-7b",
"fsdp": ["full_shard", "auto_wrap"],
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
cfg = DictDefault(
{
"base_model": "tiiuae/falcon-7b",
}
)
validate_config(cfg)
def test_mpt_gradient_checkpointing(self):
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
# Check for lower-case
cfg = DictDefault(
{
"base_model": "mosaicml/mpt-7b",
"gradient_checkpointing": True,
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
def test_flash_optimum(self):
cfg = DictDefault(
{
"flash_optimum": True,
"adapter": "lora",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"BetterTransformers probably doesn't work with PEFT adapters"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"flash_optimum": True,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"probably set bfloat16 or float16" in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"flash_optimum": True,
"fp16": True,
}
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
cfg = DictDefault(
{
"flash_optimum": True,
"bf16": True,
}
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
def test_adamw_hyperparams(self):
cfg = DictDefault(
{
"optimizer": None,
"adam_epsilon": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adafactor",
"adam_beta1": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adamw_bnb_8bit",
"adam_beta1": 0.9,
"adam_beta2": 0.99,
"adam_epsilon": 0.0001,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"optimizer": "adafactor",
}
)
validate_config(cfg)