Compare commits

..

108 Commits

Author SHA1 Message Date
Wing Lian
e6b78c1fca override the entire create_optimzier method 2024-03-19 23:19:56 -04:00
Wing Lian
a236f5eab5 add support for 4bit optimizers 2024-03-19 22:57:40 -04:00
Wing Lian
dd449c5cd8 support galore once upstreamed into transformers (#1409)
* support galore once upstreamed into transformers

* update module name for llama in readme and fix typing for all linear

* bump trl for deprecation fixes from newer transformers

* include galore as an extra and install in docker image

* fix optim_args type

* fix optim_args

* update dependencies for galore

* add galore to cicd dockerfile
2024-03-19 09:26:35 -04:00
NanoCode012
40a88e8c4a Feat: Add sharegpt multirole (#1137)
* feat(prompt): support multiple roles for sharegpt

* fix: add handling of empty role back

* feat: rebased and allowed more dynamic roles via config

* fix: variable

* chore: update message

* feat: add vicuna format

* fix: JSON serializable error

* fix: typing

* fix: don't remap for unknown keys

* fix: add roles to pydantic

* feat: add test

* chore: remove leftover print

* chore: remove leftover comment

* chore: remove print

* fix: update test to use chatml
2024-03-19 20:51:49 +09:00
Seungduk Kim
43bdc5d3de Add a config not to shuffle merged dataset (#1394) [skip ci]
* Add a config not to shuffle merged dataset

* Update README.md

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* invert the condition name

* update README

* info -> debug

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-03-19 20:51:00 +09:00
NanoCode012
b1e3e1b25f fix(config): passing gradient_checkpoint_kwargs (#1412)
* fix(config): change default use_reentrant to true

* Update trainer_builder.py

* fix: make sure to pass kwargs to enable checkpoint

* chore: lint
2024-03-19 12:57:43 +09:00
Wing Lian
2ea70ebbd8 ORPO (#1419)
* orpo trainer

* rl handling for orpo

* support for remove_unused_columns

* orpo fixes

* fix loader for orpo

* chore: lint

* fix default for remove_unused_columns

* roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora

* better handling of system message for orpo

* revert system prompt changes for chat templtes

* no need for else condition

* split dataset parsing into it's own component
2024-03-18 13:10:00 -04:00
jbl
e8c8ea64b3 Update README.md (#1418)
Add Phorm AI Badge
2024-03-17 23:47:46 -04:00
NanoCode012
d485a08393 chore(script): remove redundant setting (#1411) 2024-03-16 21:10:38 +09:00
NanoCode012
f083aed2c7 Fix(readme): Improve README QuickStart info (#1408)
* Fix(readme): Improve README QuickStart info

* chore: add to toc
2024-03-16 21:10:22 +09:00
NanoCode012
868c33954d Feat(readme): Add instructions for Google GPU VM instances (#1410) 2024-03-16 21:10:05 +09:00
Wing Lian
8df7b888ff beta support for multipack with gemmoe: (#1402) 2024-03-14 15:52:23 -04:00
Sebastian Raschka
6366b0c212 Fix Gemma 7b qlora.yml (#1405) 2024-03-14 15:44:38 -04:00
Seungduk Kim
05bcc9ea56 Train parameters exclusively in specific ranges (#1390)
* Train parameters exclusively in specific ranges

* Fix the style and update docs

* Update yaml example
2024-03-14 11:05:42 -04:00
Chirag Jain
3bd8203c35 Don't disable existing loggers when configuring axolotl logging (#1395) 2024-03-14 11:05:21 -04:00
Hamel Husain
8b12468230 Add QLoRA + FSDP Docs (#1403)
* pre commit

* Update fsdp_qlora.md
2024-03-14 11:04:51 -04:00
Chirag Jain
0976781e15 Update ChatTemplate enum to include alpaca and gemma (#1396) 2024-03-13 11:06:02 -04:00
Wing Lian
8a82d2e0a4 add handling for argilla dpo-mix (#1397) 2024-03-12 17:17:10 -04:00
Wing Lian
4326520829 chore: lint (#1389) 2024-03-10 21:02:55 -04:00
Brian Fitzgerald
b7d8a7dc4d Add Glaive conversation format support (#1365)
* Add Glaive conversation format support

* fix black formatting errors

* Fix black and pylint formatting errors

* only set role_key_tool if provided in the dataset constructor

* Update src/axolotl/prompt_strategies/sharegpt.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* sharegpt test

* tokenizer test

* fix formatting

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-03-10 20:50:25 -04:00
Seungduk Kim
b0ee9ec734 Set gradient_clipping to auto in DeepSpeed configs (#1382) [skip ci] 2024-03-10 20:50:12 -04:00
David Baker
0bc114d2e1 Fix pydantic configuration for the max_memory input (#1385) [skip ci]
* Fix pydantic configuration for the max_memory input

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-03-10 20:50:04 -04:00
Wing Lian
7659c001aa support for rslora (#1387) [skip ci] 2024-03-10 20:49:45 -04:00
Wing Lian
3fd8093717 validation for fsdp and deepspeed (#1388) [skip ci]
* validation for fsdp and deepspeed

* make sure to return data
2024-03-10 20:49:25 -04:00
Wing Lian
9b6ee83a73 FDSP + QLoRA (#1378)
* wip qlora + fsdp fixes

* more fixes

* make sure to load the lora 🤦

* only setup quantized meta on non-zero rank:

* only run setup_quantized_peft_meta_for_training for qlora+fsdp

* more fixes for qlora+fsdp

* chore: lint

* add example yml

* support mistral too

* fix for model_type and add mixtral support too

* set cpu_offload: false to reduce vram, constrain new accleerator logic to qlora + fsdp

* refactor for duplicate code
2024-03-08 14:31:01 -05:00
Wing Lian
638c2dafb5 JarvisLabs (#1372)
* add Jarvis cloud gpu and sponsorship

* whitespace
2024-03-07 10:47:32 -05:00
Wing Lian
58b0d4b0d8 update flash attention for gemma support: (#1368) 2024-03-06 10:08:54 -05:00
Hamel Husain
ed70a08348 add docs for input_output format (#1367) [skip ci]
* add docs

* add docs

* run linter
2024-03-06 09:09:49 -05:00
Wing Lian
0cfdb2c90c support for DoRA w/ PEFT (#1363) 2024-03-05 21:20:15 -05:00
Nicolas Rojas
37657473c8 Remove unsupported python version 3.9 from README (#1364) [skip ci] 2024-03-05 21:19:36 -05:00
Eric Hartford
e0f1895408 add starcoder2 (#1349)
* add starcoder2

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* chore: lint

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2024-03-05 19:49:17 -05:00
Sebastian Raschka
8984bf1722 Update tinyllama lora.yml to fix eval packing issue (#1362) 2024-03-05 14:36:29 -05:00
Wing Lian
2598c9f045 allow the sharegpt handler to also better handle datasets destined for openai finetuning (#1361)
* allow the sharegpt handler to also better handle datasets destined for openai finetuning

* make sure to support system role
2024-03-05 11:43:33 -05:00
Wing Lian
decb66e170 lora+ support (#1352)
* lora+ support

* optimizer should default to None

* include mit license
2024-03-05 07:29:23 -05:00
Wing Lian
4d09b42ee3 plain input/output prompt strategy w/o chat templates (#1346)
* plain input/output prompt strategy w/o chat templates

* disable duplicate code check

* make sure to add an eos/eot token to the end of the output so it will stop

* multi turn segement support and test
2024-03-04 16:25:16 -05:00
Chirag Jain
b5b44925ec Fix validation for early stopping (#1358) 2024-03-03 22:15:18 -05:00
NanoCode012
170d4d7092 chore: enable sample_packing for Gemma (#1351) 2024-03-01 21:56:22 -05:00
Wing Lian
00018629e7 run tests again on Modal (#1289) [skip ci]
* run tests again on Modal

* make sure to run the full suite of tests on modal

* run cicd steps via shell script

* run tests in different runs

* increase timeout

* split tests into steps on modal

* increase workflow timeout

* retry doing this with only a single script

* fix yml launch for modal ci

* reorder tests to run on modal

* skip dpo tests on modal

* run on L4s, A10G takes too long

* increase CPU and RAM for modal test

* run modal tests on A100s

* skip phi test on modal

* env not arg in modal dockerfile

* upgrade pydantic and fastapi for modal tests

* cleanup stray character

* use A10s instead of A100 for modal
2024-02-29 14:26:26 -05:00
Wing Lian
6b3b271925 fix for protected model_ namespace w pydantic (#1345) 2024-02-28 15:07:49 -05:00
Chirag Jain
3a5a2d2f34 Fix use_mlflow to be bool instead of str (#1344) 2024-02-28 12:58:29 -05:00
Wing Lian
6d4bbb877f deprecate py 3.9 support, set min pytorch version (#1343) [skip ci] 2024-02-28 12:58:05 -05:00
Wing Lian
0f985e12fe more fixes 20240228 (#1342) [skip ci]
* add missing evals_per_epoch setting

* more pydantic fixes

* more fixes

* move test from normalization to validation

* increase eval size for sample packing tests
2024-02-28 12:57:45 -05:00
Wing Lian
c1a7b3dd69 add gemma instruct chat template (#1341)
* add gemma instruct chat template

* support for chat tempalte strategy too
2024-02-27 17:20:01 -05:00
Ikko Eltociear Ashimine
2b9687f341 Update fastchat_conversation_turns.py (#1294) [skip ci]
seperated -> separated
2024-02-27 09:06:10 -05:00
Wing Lian
2c9c88b32a fix steps check for anneal on first cycle (#1316) 2024-02-27 08:56:08 -05:00
Hamel Husain
5265cd6b2c Update debugging.md (#1339) [skip ci] 2024-02-27 15:47:31 +09:00
NanoCode012
5be8b555a0 fix: checkpoint saving with deepspeed (#1321) 2024-02-27 15:46:44 +09:00
Maxime
0f6af36d50 Mps mistral lora (#1292) [skip ci]
* Lora example for Mistral on MPS backend

* Add some MPS documentation

* Update examples/mistral/lora-mps.yml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/mistral/lora-mps.yml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update README.md

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 22:39:57 -05:00
Wing Lian
3f69571943 more pydantic fixes (#1338) 2024-02-26 22:39:13 -05:00
nopperl
1e3d5305d3 Support user-defined prompt processing strategies for dpo (#1248)
* support user-defined prompt processing strategies for dpo

* interpret dict dataset types as user-defined

* fix lint errors

* setup pydantic config for validation of User defined DPO

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 18:49:34 -05:00
Maxime
16482796b0 add lion-pytorch optimizer (#1299) [skip ci]
* add lion-pytorch optimizer

* update pydantic to support lion optimizer

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 18:45:14 -05:00
Nathan Cooper
f30d062b48 Add StableLM 2 Example Scripts (#1327) [skip ci]
* Add StableLM examples and configurations

* Add FFT and LORA configuration files and modify readme with usage
2024-02-26 18:44:25 -05:00
Wing Lian
269c5436ea hotfix to exclude_unset from pydantic config when converting back to a dict (#1334) 2024-02-26 15:06:25 -05:00
Wing Lian
e7eed203d8 hotfix for missing outputs params (#1333) 2024-02-26 14:36:37 -05:00
Wing Lian
cf002312e0 hotfix for lora rank (#1332) 2024-02-26 14:28:43 -05:00
Wing Lian
7de912e097 hotfix for capabilities loading (#1331) 2024-02-26 14:24:28 -05:00
JohanWork
d75653407c ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]
* Add checkpoint logging to mlflow artifact registry

* clean up

* Update README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* update pydantic config from rebase

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 13:32:39 -05:00
NanoCode012
c6b01e0f4a chore: update readme to be more clear (#1326) [skip ci] 2024-02-26 13:32:13 -05:00
Wing Lian
cc3cebfa70 Pydantic 2.x cfg (#1239)
* WIP conversion to use pydantic for config validation

* wip, more fields, add capabilities

* wip

* update pydantic validation to match existing tests

* tweak requirements

* setup deprecated paams pydantic model

* more validations

* wrap up rest of the validations

* flesh out the rest of the options from the readme into pydantic

* fix model validators as class methods

remember to return in validator
missing return
add missing relora attributes
fix test for DictDefault change
fix sys template for mistral from fastchat change in PR 2872
fix test for batch size warning

* more missing attributes for cfg

* updates from PR feedback

* fix validation for datasets and pretrain datasets

* fix test for lora check
2024-02-26 12:24:14 -05:00
Wing Lian
5894f0e57e make mlflow optional (#1317)
* make mlflow optional

* fix xformers

don't patch swiglu if xformers not working
fix the check for xformers swiglu

* fix install of xformers with extra index url for docker builds

* fix docker build arg quoting
2024-02-26 11:41:33 -05:00
kallewoof
5cf226e177 Use yaml codeblock for config.yaml field (#1303) [skip ci] 2024-02-24 21:59:16 +09:00
NanoCode012
2ed52bd568 fix(readme): Clarify doc for tokenizer_config (#1323) [skip ci] 2024-02-24 21:55:04 +09:00
NanoCode012
a359579371 deprecate: pytorch 2.0.1 image (#1315) [skip ci]
* deprecate: pytorch 2.0.1 image

* deprecate from main image

* Update main.yml

* Update tests.yml
2024-02-22 11:39:47 +09:00
Wing Lian
2752d5f958 multipack for gemma (#1313)
* multipack for gemma

* chore: lint

* handle cache_position kwarg in updated llama modeling

* add position_ids to rotary embed call for updated llama modeling
2024-02-21 19:24:21 -05:00
Monk
9e300aca0c Adding Google's gemma Model (#1312) 2024-02-21 12:56:47 -05:00
NanoCode012
3d2cd804ae fix(readme): update inference md link (#1311) [skip ci] 2024-02-22 02:48:06 +09:00
Jared Palmer
6ab69ec5f8 Add instructions for playing with qlora model to colab example (#1290)
* Add instructions for playing with qlora model to colab example

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: JohanWork <39947546+JohanWork@users.noreply.github.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: JohanWork <39947546+JohanWork@users.noreply.github.com>
2024-02-22 02:46:27 +09:00
David Meikle
3c00f406d6 Allow load_best_model_at_end to be configured for early stopping on custom evaluation datasets (#1291)
* Allow load_best_model_at_end when using test_datasets and val_set_size is zero for custom evaluation datasets

* Fixed formatting following failed Lint check
2024-02-22 00:57:18 +09:00
NanoCode012
a7a9a1433a fix(examples): remove is_*_derived as it's parsed automatically (#1297) 2024-02-22 00:52:46 +09:00
Leonardo Emili
e2786cce6a Validation always happens on first step (#1300) 2024-02-22 00:52:24 +09:00
Leonardo Emili
5a5d47458d Add seq2seq eval benchmark callback (#1274)
* Add CausalLMBenchEvalCallback for measuring seq2seq performance

* Fix code for pre-commit

* Fix typing and improve logging

* eval_sample_packing must be false with CausalLMBenchEvalCallback
2024-02-13 08:24:30 -08:00
김진원
8430db22e2 Scheduler implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (#1273) 2024-02-12 21:23:28 -08:00
Wing Lian
4b997c3e1a allow the optimizer prune ratio for ReLoRA to be configurable (#1287)
* allow the optimizer prune ration for relora to be configurable

* update docs for relora

* prevent circular imports
2024-02-12 11:39:51 -08:00
Maxime
fac2d98c26 Add MPS support (#1264)
* add mps support

* linter stuff

* CI fixes

* install packaging for various tests

* Update setup.py

* Revert "install packaging for various tests"

This reverts commit 980e7aa44d.

* Revert "CI fixes"

This reverts commit 4609e3b166.

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-12 08:30:32 -05:00
Wing Lian
ea00dd0852 don't use load and push together (#1284) 2024-02-09 14:54:31 -05:00
Hamel Husain
b2a4cb4396 Update README.md (#1281) 2024-02-09 07:38:08 -08:00
Wing Lian
aaf54dc730 run the docker image builds and push on gh action gpu runners (#1218) 2024-02-09 10:32:54 -05:00
Hamel Husain
9bca7db133 add support for https remote yamls (#1277) 2024-02-08 20:02:17 -08:00
Hamel Husain
91cf4ee72c allow remote data paths (#1278)
* allow remote data paths

* add docs about public url

* only allow https

* better docs

* better docs
2024-02-08 15:02:35 -08:00
Wing Lian
1daecd161e copy edits (#1276) 2024-02-08 09:00:04 -05:00
Wing Lian
4a654b331e Add link to axolotl cloud image on latitude (#1275) 2024-02-08 08:50:11 -05:00
Wing Lian
5698943263 simplify haldning for newer multipack patches so they can be added in a single place (#1270) 2024-02-07 10:46:04 -05:00
Wing Lian
411293bdca contributor avatars (#1269) 2024-02-07 07:09:01 -08:00
Zac Brannelly
73f1bdaa15 Fix bug preventing model_kwargs being injected (#1262) 2024-02-07 09:38:35 -05:00
JohanWork
1c7ed26785 lock pytorch (#1247) [skip ci] 2024-02-06 07:48:26 -05:00
Philip May
13eea21f9b Add more save strategies for DPO training. (#1255)
* Set save_strategy and save_steps in HFDPOTrainerBuilder

* fix doublicate save_steps
2024-02-06 00:38:43 -05:00
Chirag Jain
1072f28874 Fix typo bloat16 -> bfloat16 (#1257) 2024-02-06 00:38:14 -05:00
Wing Lian
c7cf3810bd Pretrain transforms (#1261)
* wip for pretraining/iterable data with arbitrary prompt strategies

* more fixes, wip

* more fixes for custom pretraining

* iterable ds wrapper not needed

* remove extra features

* chore: lint

* update pretraning example yml

* fix order for partials

* fixup for tests
2024-02-06 00:37:03 -05:00
Wing Lian
8c2e05ade3 relora: magnitude pruning of the optimizer (#1245)
* magnitude pruning of the optimizer

* add alpaca chat template and fix relora patch

* fix handling of lora adapter for relora

* fix merge and save call

* fixes for 8-bit lora merge

* save intermediate checkpoint adapters

* auto merge

* fix eval check

* handle relora annealing

* fix anneal step logic

* chore: lint

* misx fix

* fix types

* Update tests/e2e/test_relora_llama.py

* check for safetensors saved from relora
2024-02-06 00:35:30 -05:00
NanoCode012
2d65f470d5 fix(model): apply gate fp32 only for mixtral (#1241)
* fix(model): apply gate fp32 only for mixtral

* Update src/axolotl/utils/models.py

* fix gate layer check

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-01 13:55:05 -05:00
Wing Lian
dfd188502a add contact info for dedicated support for axolotl [skip ci] (#1243) 2024-02-01 12:59:07 -05:00
Wing Lian
00568c1539 support for true batches with multipack (#1230)
* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* setup forward patch. WIP

* fix patch file

* support for multipack w/o flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
2024-02-01 10:18:42 -05:00
Wing Lian
c67fb71583 Peft deepspeed resume (#1227)
* import deepspeed integration

* monkeypatch peft adapater with deepspeed for resume from checkpoint

* fix patch

* fix patches attempt 2

* make sure to set lora_model_dir

* skip pylint for deepspeed.utils

* pick up upstream fix in transformers

* remove monkeypatch for deepspeed/peft fix

* no need to set the lora_model_dir on resume

* unset load_in_*bit when using quant config

* guard before del

* better handling of load_in* kwargs
2024-01-31 18:13:29 -05:00
DreamGenX
25e037fe2d Support for additional_special_tokens (#1221) [skip ci]
* Support for additional_special_tokens

* Support for additional_special_tokens. Adjust whitespace.

* Support for additional_special_tokens. Use correct quotes.

* Support for additional_special_tokens. Safe pop.

* Support for additional_special_tokens. nt.

* Support for additional_special_tokens. cfg.special_tokens may be None.

* add token if not in vocabulary when adding additional_special_tokens

* fix logic for copy/pasta

* bugfix for popping from config and tokenizer reload

* no need to add tokens manually now with previous bugfix

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-31 18:13:13 -05:00
Hamel Husain
52c83d30bf Update rlhf.md (#1237) [skip ci] 2024-01-31 17:27:35 -05:00
Wing Lian
d113331e9a add a helpful motd for cloud image (#1235) [skip ci] 2024-01-31 10:26:02 -05:00
Wing Lian
8f2b591baf set torch version to what is installed during axolotl install (#1234) 2024-01-31 08:47:34 -05:00
DreamGenX
5787e1a23f Fix and document test_datasets (#1228)
* Make sure test_dataset are used and treat val_set_size.

* Add test_datasets docs.

* Apply suggestions from code review

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-31 06:48:57 -05:00
xhedit
8608d8003e Fix typo (#1231) [skip ci] 2024-01-31 06:46:55 -05:00
Wing Lian
4cb7900a56 Peft lotfq (#1222)
* loftq support for lora

* fix loftq check

* update readme for loftq

* readability cleanup

* use peft main for loftq fixes, remove unnecessary special tokens

* remove unused test from older deprecation
2024-01-28 18:50:08 -05:00
Filippo Broggini
18f811978c FEAT: add tagging support to axolotl for DPOTrainer (#1209)
* Add AxolotlDPOTrainer

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-26 20:01:57 -05:00
Wing Lian
afb5dd9655 Update FUNDING.yml [skip ci] 2024-01-26 20:00:28 -05:00
Wing Lian
8da1633124 Revert "run PR e2e docker CI tests in Modal" (#1220) [skip ci] 2024-01-26 16:50:44 -05:00
Wing Lian
36d053f6f0 run PR e2e docker CI tests in Modal (#1217) [skip ci]
* wip modal for ci

* handle falcon layernorms better

* update

* rebuild the template each time with the pseudo-ARGS

* fix ref

* update tests to use modal

* cleanup ci script

* make sure to install jinja2 also

* kickoff the gh action on gh hosted runners and specify num gpus
2024-01-26 16:13:27 -05:00
JohanWork
af29d81f80 ADD: warning if hub_model_id ist set but not any save strategy (#1202)
* warning if hub model id set but no save

* add warning

* move the warning

* add test

* allow more public methods for tests for now

* fix tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-26 10:38:55 -05:00
Wing Lian
1b180034c7 ensure the tests use the same version of torch as the latest base docker images (#1215) [skip ci] 2024-01-26 10:38:30 -05:00
DreamGenX
62ca4a2b71 Respect sliding_window=None (#1214) 2024-01-26 07:43:37 -05:00
Igor Berlenko
5407ddd233 Update qlora.yml - remove max_packed_sequence_len (#1210) [skip ci] 2024-01-26 07:43:05 -05:00
154 changed files with 7416 additions and 1315 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1,6 +1,6 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: axolotl_ai # Replace with a single Ko-fi username

View File

@@ -59,6 +59,7 @@ body:
label: Config yaml
description: |
Please attach the config yaml!
render: yaml
- type: textarea
id: possible-solution

View File

@@ -7,16 +7,11 @@ jobs:
build-base:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: self-hosted
runs-on: axolotl-gpu-runner
strategy:
fail-fast: false
matrix:
include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"

View File

@@ -17,6 +17,6 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0

View File

@@ -9,21 +9,16 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121
cuda_version: 12.1.0
@@ -35,7 +30,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: [self-hosted, gpu, docker]
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -56,27 +51,17 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
load: true
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: Push to Docker Hub
if: github.event_name != 'pull_request'
run: |
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
if [ -n "$latest_tag" ]; then
docker push "$latest_tag"
fi
build-axolotl-runpod:
needs: build-axolotl
@@ -85,11 +70,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
@@ -106,7 +86,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: [self-hosted, gpu, docker]
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -23,7 +23,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
@@ -33,7 +33,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.9", "3.10", "3.11"]
python_version: ["3.10", "3.11"]
timeout-minutes: 10
steps:
@@ -58,8 +58,8 @@ jobs:
docker-e2e-tests:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, gpu, docker]
timeout-minutes: 30
runs-on: [self-hosted, modal]
timeout-minutes: 60
needs: [pre-commit, pytest]
strategy:
@@ -69,44 +69,32 @@ jobs:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
pytorch: 2.1.2
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.1
pytorch: 2.1.2
num_gpus: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
- name: Install Python
uses: actions/setup-python@v5
with:
images: winglian/axolotl-tests
- name: Build Docker image
python-version: "3.10"
- name: Install Modal
run: |
# Set up build arguments
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
CUDA="${{ matrix.cuda }}"
PYTORCH_VERSION="${{ matrix.pytorch }}"
# Build the Docker image
docker build . \
--file ./docker/Dockerfile-tests \
--build-arg BASE_TAG=$BASE_TAG \
--build-arg CUDA=$CUDA \
--build-arg GITHUB_REF=$GITHUB_REF \
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
--tag ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} \
--no-cache
- name: Unit Tests w docker image
python -m pip install --upgrade pip
pip install modal jinja2
- name: Update env vars
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: GPU Unit Tests w docker image
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
- name: GPU Unit Tests monkeypatched w docker image
run: |
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
- name: Prune image from docker
if: github.ref != 'refs/heads/main'
run: |
docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
modal run cicd.tests

5
.gitignore vendored
View File

@@ -167,3 +167,8 @@ cython_debug/
# WandB
# wandb creates a folder to store logs for training runs
wandb
# Runs
lora-out/*
qlora-out/*
mlruns/*

View File

@@ -1,5 +1,5 @@
[mypy]
plugins = pydantic.mypy
exclude = venv
[mypy-alpaca_lora_4bit.*]
@@ -32,6 +32,9 @@ ignore_missing_imports = True
[mypy-bitsandbytes]
ignore_missing_imports = True
[mypy-requests]
ignore_missing_imports = True
[mypy-datasets]
ignore_missing_imports = True

View File

@@ -31,6 +31,7 @@ repos:
additional_dependencies:
[
'types-PyYAML',
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5

298
README.md
View File

@@ -13,6 +13,9 @@ Features:
- Log results and optionally checkpoints to wandb or mlflow
- And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
<table>
<tr>
@@ -22,21 +25,25 @@ Features:
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Installation](#installation)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
- [LambdaLabs](#lambdalabs)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Windows](#windows)
- [Mac](#mac)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference)
- [Inference](#inference-playground)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- Advanced Topics
- [Multipack](./docs/multipack.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [RLHF & DPO](./docs/rlhf.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
@@ -84,17 +91,18 @@ Features:
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.9 and Pytorch >=2.0.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
### For developers
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
@@ -118,15 +126,20 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
```
## Installation
## Advanced Setup
### Environment
#### Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
```
Or run on the current files for development:
@@ -145,7 +158,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
```
It additionally:
@@ -160,7 +173,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
</details>
#### Conda/Pip venv
1. Install python >=**3.9**
1. Install python >=**3.10**
2. Install pytorch stable https://pytorch.org/get-started/locally/
@@ -179,9 +192,14 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### LambdaLabs
#### Bare Metal Cloud GPU
##### LambdaLabs
<details>
<summary>Click to Expand</summary>
@@ -189,11 +207,11 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
1. Install python
```bash
sudo apt update
sudo apt install -y python3.9
sudo apt install -y python3.10
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
python -V # should be 3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
```
@@ -225,21 +243,46 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
```
</details>
##### GCP
<details>
<summary>Click to Expand</summary>
Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.
Make sure to run the below to uninstall xla.
```bash
pip uninstall -y torch_xla[tpu]
```
</details>
#### Windows
Please use WSL or Docker!
#### Mac
Use the below instead of the install method in QuickStart.
```
pip3 install -e '.'
```
More info: [mac.md](/docs/mac.md)
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
@@ -249,32 +292,33 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended):
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
- path: <your-path>
type: sharegpt
conversation: llama-2
```
#### Pretraining
- `completion`: raw corpus
```json
{"text": "..."}
```
Note: Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```yaml
pretraining_dataset: # hf path only
```
#### Supervised finetuning
##### Instruction
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
<details>
<summary>See other formats</summary>
@@ -351,14 +395,37 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
```
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
</details>
##### Template-Free
- `input_output`: template-free prompt construction
```json
{"segments": [{"label": true|false, "text": "..."}]}
```
This is a special format that allows you to construct prompts without using templates. This is for advanced users who want more freedom with prompt construction. See [these docs](docs/input_output.md) for more details.
##### Conversation
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
<details>
<summary>See other formats</summary>
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
@@ -374,6 +441,8 @@ Have dataset(s) in one of the following format (JSONL recommended):
</details>
Note: `type: sharegpt` opens a special config `conversation:` that enables conversions to many Conversation types. See dataset section under [all yaml options](#all-yaml-options).
#### How to add custom prompts
For a dataset that is preprocessed for instruction purposes:
@@ -395,12 +464,16 @@ datasets:
format: "[INST] {instruction} [/INST]"
no_input_format: "[INST] {instruction} [/INST]"
```
See full config options under [all yaml options](#all-yaml-options).
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
```yaml
- path: ...
```
### Config
@@ -414,22 +487,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- dataset
```yaml
sequence_len: 2048 # max token length for prompt
# huggingface repo
datasets:
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
type: alpaca
# huggingface repo with specific configuration/subset
datasets:
# huggingface repo with specific configuration/subset
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
datasets:
# huggingface repo with multiple named configurations/subsets
- path: bigcode/commitpackft
name:
- ruby
@@ -437,39 +506,42 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
datasets:
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
conversation: chatml
conversation: chatml # default: vicuna_v1.1
# local
datasets:
# local
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
dataset:
# dataset with splits, but no train split
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
dataset:
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
- loading
```yaml
load_in_4bit: true
load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
@@ -477,7 +549,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # qlora or leave blank for full finetune
adapter: lora # 'qlora' or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
@@ -486,9 +558,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- v_proj
```
<details>
<details id="all-yaml-options">
<summary>All yaml options (click me)</summary>
<summary>All yaml options (click to expand)</summary>
```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
@@ -500,8 +572,8 @@ base_model_ignore_patterns:
# You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer
revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
@@ -518,15 +590,16 @@ tokenizer_legacy:
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# (Internal use only)
# Used to identify which the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_qwen_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration
model_config:
overrides_of_model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
@@ -543,8 +616,6 @@ bnb_config_kwargs:
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
@@ -580,9 +651,13 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].
# Custom user instruction prompt
- path: repo
@@ -607,7 +682,22 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:
# use RL training: dpo, ipo, kto_pair
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
- path: /workspace/data/eval.jsonl
ds_type: json
# You need to specify a split. For "json" datasets the default split is called "train".
split: train
type: completion
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair'
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
@@ -627,7 +717,7 @@ dataset_processes: # defaults to os.cpu_count() if not set
# Only needed if cached dataset is taking too much storage
dataset_keep_in_memory:
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
hub_model_id: # private repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
@@ -696,10 +786,18 @@ lora_modules_to_save:
lora_fan_in_fan_out: false
peft:
# Configuration options for loftq initialization for LoRA
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
loftq_config:
loftq_bits: # typically 4 bits
# ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
@@ -715,6 +813,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -748,7 +847,8 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -767,7 +867,7 @@ group_by_length: false
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
# use_reentrant: false
# use_reentrant: true
# Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
@@ -777,14 +877,11 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim
lr_div_factor: # Learning rate div factor
# For log_sweep optim
log_sweep_min_lr:
log_sweep_max_lr:
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
@@ -810,7 +907,26 @@ log_sweep_max_lr:
# - paged_adamw_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
# For Galore Optimizers the following optim_args are available
# rank: # type: int
# update_proj_gap # type: int
# scale # type: float
# proj_type: # type: str, default = std
# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
optim_target_modules:
# - self_attn # for llama
# - mlp
# Specify weight decay
weight_decay:
# adamw hyperparams
@@ -956,6 +1072,9 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml
```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
@@ -1004,6 +1123,10 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
@@ -1065,7 +1188,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
@@ -1126,7 +1249,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
@@ -1135,9 +1258,11 @@ Having misalignment between your prompts during training and inference can cause
See [this debugging guide](docs/debugging.md) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
## Need help? 🙋♂️
## Need help? 🙋
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
Need dedicated support? Please contact us at [✉wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
## Badge ❤🏷️
@@ -1171,13 +1296,28 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Sponsors 🤝❤
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
@@ -1206,4 +1346,6 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
#### 🥉 Bronze Sponsors - $500/mo
- [JarvisLabs.ai](https://jarvislabs.ai)
---

39
cicd/Dockerfile.jinja Normal file
View File

@@ -0,0 +1,39 @@
FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

5
cicd/cicd.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/bin/bash
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

75
cicd/tests.py Normal file
View File

@@ -0,0 +1,75 @@
"""
modal application to run axolotl gpu tests in Modal
"""
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
"CUDA": os.environ.get("CUDA", "118"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -16,6 +16,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -20,6 +20,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,6 +24,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,6 +24,7 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -2,7 +2,6 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -3,9 +3,10 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -7,8 +7,8 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.0.1"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

View File

@@ -11,6 +11,7 @@ EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
@@ -18,6 +19,7 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh

View File

@@ -3,9 +3,10 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ARG PYTORCH_VERSION="2.1.2"
ARG GITHUB_REF="main"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -24,9 +25,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -74,7 +74,6 @@ pip3 install -e '.[flash-attn,deepspeed]'
If you developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh). You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host).
```bash
### Configuration

37
docs/fsdp_qlora.md Normal file
View File

@@ -0,0 +1,37 @@
# FDSP + QLoRA
## Background
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
Below, we describe how to use this feature in Axolotl.
## Usage
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> ![Tip]
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
## References
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
- Related HuggingFace PRs Enabling FDSP + QLoRA:
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.

BIN
docs/images/4d-mask.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 239 KiB

260
docs/input_output.md Normal file
View File

@@ -0,0 +1,260 @@
# Template-free prompt construction with the `input_output` format
<!-- TOC -->
- [Background](#background)
- [Masking Inputs](#masking-inputs)
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
- [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
- [1. Prepare Data](#1-prepare-data)
- [2. Use `type: input_output`](#2-use-type-input_output)
- [3. Check the prompts](#3-check-the-prompts)
<!-- /TOC -->
<a id="markdown-background" name="background"></a>
## Background
<a id="markdown-masking-inputs" name="masking-inputs"></a>
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
### You may not want prompt templates
However, there are many situations where you don't want to use one of
these formats or templates (I usually don't!). This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
### The `input_output` format
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
<a id="markdown-usage" name="usage"></a>
## Usage
This is how you can use the `input_output` format:
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
### 1. Prepare Data
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output`.jsonl` pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
### 2. Use `type: input_output`
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
### 3. Check the prompts
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ingored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
| token | label | id |
|-------|-------|-------|
| 0 | \<s\> | 1 |
| 1 | Hello | 22557 |
| 2 | \\n | 13 |
| 3 | hi | 12014 |
| 4 | there | 736 |
| 5 | ! | 28808 |
| 6 | . | 28723 |
| 7 | | 28705 |
| 8 | good | -100 |
| 9 | bye | -100 |
| 10 | | -100 |
| 11 | fare | 19111 |
| 12 | well | 5458 |
| 13 | \</s\>| 2 |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```

18
docs/mac.md Normal file
View File

@@ -0,0 +1,18 @@
# Mac M series support
Currently Axolotl on Mac is partially usable, many of the dependencies of Axolotl including Pytorch do not support MPS or have incomplete support.
Current support:
- [x] Support for all models
- [x] Full training of models
- [x] LoRA training
- [x] Sample packing
- [ ] FP16 and BF16 (awaiting AMP support for MPS in Pytorch)
- [ ] Tri-dao's flash-attn (until it is supported use spd_attention as an alternative)
- [ ] xformers
- [ ] bitsandbytes (meaning no 4/8 bits loading and bnb optimizers)
- [ ] qlora
- [ ] DeepSpeed
Untested:
- FSDP

View File

@@ -1,4 +1,11 @@
# Multipack
# Multipack (Sample Packing)
## Visualization of Multipack with Flash Attention
Because Flash Attention simply drops the attention mask, we do not need to
construct a 4d attention mask. We only need to concatenate the sequences into
a single batch and let flash attention know where each new sequence begins.
4k context, bsz =4,
each character represents 256 tokens
@@ -49,3 +56,18 @@ w packing ( note it's the same effective number of tokens per step, but a true b
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```
cu_seqlens:
[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]
## Multipack without Flash Attention
Multipack can still be achieved without Flash attention, but with lower packing
efficiency as we are not able to join multiple batches into a single batch due to
context length limits without flash attention. We can use either Pytorch's Scaled
Dot Product Attention implementation or native Pytorch attention implementation
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
to pack sequences together and avoid cross attention.
<img src="./images/4d-mask.png" alt="axolotl" width="800">

29
docs/optimizers.md Normal file
View File

@@ -0,0 +1,29 @@
# Optimizers
Optimizers are an important component when training LLMs. Optimizers are responsible for updating the model's weights (parameters) based on the gradients computed during backpropagation.
The goal of an optimizer is to minimize the loss function.
### Adam/AdamW Optimizers
```yaml
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
weight_decay: 0.0
```
### GaLore Optimizer
https://huggingface.co/papers/2403.03507
```yaml
optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
optim_args:
rank: 128
update_proj_gap: 200
scale: 0.25
proj_type: std
optim_target_modules:
- mlp
- attn
```

View File

@@ -12,8 +12,8 @@ feedback. Various methods include, but not limited to:
### RLHF using Axolotl
[!IMPORTANT]
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
@@ -34,6 +34,21 @@ datasets:
rl: ipo
```
#### ORPO
Paper: https://arxiv.org/abs/2403.07691
```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: orpo.chat_template
```
#### Using local dataset files
```yaml
datasets:

View File

@@ -11,7 +11,6 @@ val_set_size: 0.05
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -43,6 +43,7 @@
},
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\""
@@ -176,6 +177,24 @@
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false

View File

@@ -5,7 +5,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false

65
examples/gemma/qlora.yml Normal file
View File

@@ -0,0 +1,65 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -67,6 +66,3 @@ weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,5 +1,4 @@
base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_disable_exllama: true
model_type: AutoModelForCausalLM

View File

@@ -0,0 +1,69 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
peft:
loftq_config:
loftq_bits: 4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -57,7 +56,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
@@ -65,6 +64,3 @@ weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
@@ -65,6 +64,3 @@ weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,7 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,7 +2,6 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -61,7 +60,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3

View File

@@ -1,7 +1,6 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -49,7 +48,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -0,0 +1,79 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./lora-out
eval_sample_packing: false
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
sdp_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,74 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:

View File

@@ -16,12 +16,12 @@ output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
# - ^lm_head.weight$
# - ^model.embed_tokens.weight$[:32000]
# - model.layers.2[0-9]+.block_sparse_moe.gate
# - model.layers.2[0-9]+.block_sparse_moe.experts
# - model.layers.3[0-9]+.block_sparse_moe.gate
# - model.layers.3[0-9]+.block_sparse_moe.experts
model_config:
output_router_logits: true
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero2.json

View File

@@ -1,7 +1,6 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
@@ -68,7 +67,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: true
@@ -58,7 +57,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: false
@@ -58,7 +57,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -0,0 +1,69 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,66 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,36 @@
# StableLM 2
This repository contains examples for training and processing using StableLM-2. It also includes a section to help you estimate the GPU requirements for your specific use case.
## Estimating GPU Requirements
| type | deepspeed | batch size | context length | vRAM GPU (GBs) |
|---------------|-----------|------------|----------------|----------------|
| full finetune | N/A | 1 | 4096 | ~21.5GBs |
| full finetune | zero2 | 1 | 4096 | ~20GBs |
| lora | N/A | 1 | 4096 | ~16.6GBs |
The above are estimates and might differ slight depending on the setup for example whether you pack your sequence lengths or not (the above assumes you do to length 4096).
This blog post from Hamel Husain was a great resource for estimating these numbers: https://hamel.dev/notes/llm/03_estimating_vram.html
## Training
We have example scripts here for both full finetuning and lora using the popular alpaca dataset:
```shell
# preprocess the dataset
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/stablelm-2/1.6b/lora.yml
```
Single GPU Training:
```shell
python -m axolotl.cli.train examples/stablelm-2/fft.yml --deepspeed deepspeed_configs/zero2.json
# OR
python -m axolotl.cli.train examples/stablelm-2/1.6b/lora.yml
```
Multinode GPU Training with `accelerate`:
```shell
# make sure you've configured accelerate properly
accelerate launch -m axolotl.cli.train examples/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json
```

View File

@@ -0,0 +1,69 @@
base_model: bigcode/starcoder2-3b
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.2
output_dir: ./qlora
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_steps:
eval_table_size:
saves_per_epoch: 4
save_steps:
save_total_limit: 2
debug:
deepspeed:
weight_decay:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,64 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -16,6 +15,7 @@ output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora

View File

@@ -2,7 +2,6 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -12,6 +11,7 @@ max_steps: 200
pretraining_dataset:
path: c4
name: en
type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out

View File

@@ -1,7 +1,6 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,8 +1,7 @@
base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -29,7 +28,7 @@ num_epochs: 1
val_set_size: 0.1
evals_per_epoch: 5
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
eval_sample_packing: false
eval_batch_size: 1

View File

@@ -1,3 +1,4 @@
pre-commit
black
mypy
types-requests

View File

@@ -1,16 +1,18 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.7.1
transformers==4.37.0
peft==0.9.0
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
tokenizers==0.15.0
bitsandbytes>=0.41.1
bitsandbytes>=0.43.0
accelerate==0.26.1
deepspeed>=0.13.1
deepspeed==0.13.1
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
datasets>=2.15.0
flash-attn==2.3.3
flash-attn==2.5.5
sentencepiece
wandb
einops
@@ -20,14 +22,13 @@ hf_transfer
colorama
numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.0
evaluate==0.4.1
scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.34
fschat==0.2.36
gradio==3.50.2
tensorboard
@@ -38,4 +39,8 @@ s3fs
gcsfs
# adlfs
trl>=0.7.9
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
fastcore>=1.5.29
lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
yacs

17
scripts/motd Normal file
View File

@@ -0,0 +1,17 @@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
```
cd /workspace
rm -rf /workspace/axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
cd axolotl
pip install --no-deps -e .
```

View File

@@ -1,5 +1,7 @@
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
@@ -16,6 +18,7 @@ def parse_requirements():
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
@@ -26,10 +29,25 @@ def parse_requirements():
_install_requires.append(line)
try:
torch_version = version("torch")
if torch_version.startswith("2.1."):
if "Darwin" in platform.system():
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
else:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 1):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
except PackageNotFoundError:
pass
@@ -50,13 +68,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.3.3",
"flash-attn==2.5.5",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed>=0.13.1",
"deepspeed==0.13.1",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -65,5 +83,14 @@ setup(
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
"galore": [
"galore_torch",
],
},
)

View File

@@ -1,16 +1,19 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import gradio as gr
import requests
import torch
import yaml
@@ -20,6 +23,7 @@ from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -59,6 +63,52 @@ def print_axolotl_text_art(suffix=None):
print(ascii_art)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
@@ -164,6 +214,8 @@ def do_inference_gradio(
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
@@ -270,14 +322,14 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(config)
config = choose_config(Path(config))
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
@@ -290,7 +342,22 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
validate_config(cfg)
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"compute_capability": gpu_version,
},
)
prepare_optim_env(cfg)

View File

@@ -3,6 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -23,7 +24,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
@@ -53,7 +54,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
if parsed_cfg.rl:
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -3,6 +3,7 @@ CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -25,7 +26,7 @@ def shard(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,11 +3,12 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Tuple
from typing import Tuple, Union
import fire
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -24,10 +25,10 @@ from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
@@ -46,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else:
register_chatml_template()
if cfg.rl:
if cfg.rl and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,6 +6,7 @@ import logging
from dataclasses import dataclass, field
from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

View File

View File

@@ -0,0 +1,55 @@
"""module for building the auto wrap policy for FSDP"""
import functools
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]
def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer
def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""
def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)
lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)
return get_wrapping_policy

View File

@@ -5,38 +5,52 @@ Builder for the training args and trainer
import abc
import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import List, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import lpmm
import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch import nn
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.core.trainers import OptimizerNames
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoMlflowCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import (
@@ -49,8 +63,15 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
# monkeypatch so it accepts our custom optimizers
transformers.training_args.OptimizerNames = OptimizerNames
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
@@ -59,6 +80,26 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
@@ -82,6 +123,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
@@ -106,6 +151,14 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
@@ -118,6 +171,9 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
@@ -135,6 +191,26 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
class AxolotlTrainer(Trainer):
@@ -151,13 +227,122 @@ class AxolotlTrainer(Trainer):
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@staticmethod
def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
optim_args = {}
if args.optim_args:
for mapping in args.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optim_args[key] = value
optimizer_kwargs = {"lr": args.learning_rate}
adam_kwargs = {
"betas": (args.adam_beta1, args.adam_beta2),
"eps": args.adam_epsilon,
}
if args.optim in [
OptimizerNames.LPMM_ADAMW_4BIT,
OptimizerNames.LPMM_ADAMW_4BIT_FUSED,
]:
optimizer_cls = lpmm.optim.AdamW
optimizer_kwargs.update(adam_kwargs)
if args.optim == OptimizerNames.LPMM_ADAMW_4BIT_FUSED:
optimizer_kwargs.update({"fused": True})
return optimizer_cls, optimizer_kwargs
return Trainer.get_optimizer_cls_and_kwargs(
args,
model=model,
)
def create_optimizer(self):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
(
optimizer_cls,
optimizer_kwargs,
) = AxolotlTrainer.get_optimizer_cls_and_kwargs(self.args)
if self.args.loraplus_lr_ratio:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
else:
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -192,6 +377,16 @@ class AxolotlTrainer(Trainer):
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
@@ -213,11 +408,19 @@ class AxolotlTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_train_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
batch_size=batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
@@ -227,11 +430,19 @@ class AxolotlTrainer(Trainer):
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
batch_size=batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
@@ -347,22 +558,111 @@ class AxolotlTrainer(Trainer):
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
def orpo_compute_custom_loss(self, logits, labels):
logits = logits.contiguous()
loss = 0.0
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return kwargs
# Flatten the tokens
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
dim=-1
)
return loss
def orpo_compute_logps(
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
):
# Get the shape of chosen_attention_mask[:, :-1]
chosen_shape = chosen_attention_mask[:, :-1].shape
# Calculate the padding size
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
# Pad prompt_attention_mask with zeros to match the desired shape
prompt_attention_mask_padded = torch.nn.functional.pad(
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
)
# Perform the subtraction operation
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
per_token_logps = torch.gather(
logits[:, :-1, :].log_softmax(-1),
dim=2,
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
).squeeze(2)
return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
dtype=torch.float64
) / mask.sum(dim=1).to(dtype=torch.float64)
def orpo_compute_loss(self, model, inputs, return_outputs=False):
outputs_neg = model(
**{
"input_ids": inputs["rejected_input_ids"],
"attention_mask": inputs["rejected_attention_mask"],
"labels": inputs["rejected_labels"],
},
output_hidden_states=True,
)
outputs_pos = model(
**{
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"labels": inputs["labels"],
},
output_hidden_states=True,
)
# Calculate NLL loss
pos_loss = self.orpo_compute_custom_loss(
logits=outputs_pos.logits, labels=inputs["input_ids"]
)
# Calculate Log Probability
pos_prob = self.orpo_compute_logps(
prompt_attention_mask=inputs["prompt_attention_mask"],
chosen_inputs=inputs["input_ids"],
chosen_attention_mask=inputs["attention_mask"],
logits=outputs_pos.logits,
)
neg_prob = self.orpo_compute_logps(
prompt_attention_mask=inputs["prompt_attention_mask"],
chosen_inputs=inputs["rejected_input_ids"],
chosen_attention_mask=inputs["rejected_attention_mask"],
logits=outputs_neg.logits,
)
# Calculate log odds
log_odds = (pos_prob - neg_prob) - (
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
sig_ratio = torch.nn.functional.sigmoid(log_odds)
ratio = torch.log(sig_ratio)
# Calculate the Final Loss
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
dtype=torch.bfloat16
)
metrics = {}
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
self.store_metrics(metrics, train_eval="train")
return (loss, outputs_pos) if return_outputs else loss
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
@@ -370,12 +670,82 @@ class AxolotlTrainer(Trainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = self._sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess()
if self.args.qlora is False:
return res
# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)
mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']
if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrapping_policy(),
cpu_offload=False,
use_orig_params=False,
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin
return res
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -459,10 +829,14 @@ class ReLoRATrainer(AxolotlTrainer):
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
@@ -471,6 +845,24 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -577,7 +969,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow:
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
@@ -597,6 +993,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
@@ -655,15 +1056,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
else:
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
@@ -718,7 +1118,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.val_set_size == 0:
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
@@ -745,6 +1150,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
@@ -805,7 +1212,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and self.cfg.val_set_size > 0
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
@@ -826,6 +1236,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
@@ -836,12 +1262,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.flash_attention is not True
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing
if self.cfg.eval_sample_packing is not False
@@ -850,31 +1282,70 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
trainer_kwargs = {}
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
@@ -944,7 +1415,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
if use_batch_sampler_collator:
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
@@ -1041,13 +1517,21 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
"use_reentrant": False
}
# set save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
save_strategy="steps",
save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps,
logging_first_step=True,
@@ -1076,7 +1560,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer = DPOTrainer(
dpo_trainer = AxolotlDPOTrainer(
self.model,
self.model_ref,
args=training_args,

View File

@@ -0,0 +1,40 @@
"""module for trainer helpers like OptimizerNames"""
from transformers.utils import ExplicitEnum
class OptimizerNames(ExplicitEnum):
"""
Stores the acceptable string identifiers for optimizers.
"""
ADAMW_HF = "adamw_hf"
ADAMW_TORCH = "adamw_torch"
ADAMW_TORCH_FUSED = "adamw_torch_fused"
ADAMW_TORCH_XLA = "adamw_torch_xla"
ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_ANYPRECISION = "adamw_anyprecision"
SGD = "sgd"
ADAGRAD = "adagrad"
ADAMW_BNB = "adamw_bnb_8bit"
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
LION_8BIT = "lion_8bit"
LION = "lion_32bit"
PAGED_ADAMW = "paged_adamw_32bit"
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
PAGED_LION = "paged_lion_32bit"
PAGED_LION_8BIT = "paged_lion_8bit"
RMSPROP = "rmsprop"
RMSPROP_BNB = "rmsprop_bnb"
RMSPROP_8BIT = "rmsprop_bnb_8bit"
RMSPROP_32BIT = "rmsprop_bnb_32bit"
GALORE_ADAMW = "galore_adamw"
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
GALORE_ADAFACTOR = "galore_adafactor"
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
LPMM_ADAMW_4BIT = "lmpp_adamw_4bit"
LPMM_ADAMW_4BIT_FUSED = "lmpp_adamw_4bit_fused"

View File

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
dataset: Dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,

View File

@@ -30,6 +30,7 @@ class ColorfulFormatter(Formatter):
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",

133
src/axolotl/loraplus.py Normal file
View File

@@ -0,0 +1,133 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer

View File

View File

@@ -0,0 +1,46 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access
import torch
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data._utils.worker import _worker_loop
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if isinstance(possibly_batched_index[0], list):
data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
if self.auto_collation:
if (
hasattr(self.dataset, "__getitems__")
and self.dataset.__getitems__
):
data[i] = self.dataset.__getitems__(possibly_batched_index_)
else:
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
else:
data[i] = self.dataset[possibly_batched_index_]
else:
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
def patch_fetchers():
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
def patched_worker_loop(*args, **kwargs):
patch_fetchers()
return _worker_loop(*args, **kwargs)
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
patch_fetchers()

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for falcon
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -106,7 +106,7 @@ def get_turns( # pylint: disable=too-many-return-statements
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(

View File

@@ -44,6 +44,18 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_swiglu_available() -> bool:
from xformers.ops.common import get_xformers_operator
try:
get_xformers_operator("swiglu_packedw")()
return True
except RuntimeError as exc:
if "No such operator xformers::swiglu_packedw " in str(exc):
return False
return True
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
@@ -275,7 +287,9 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -425,7 +439,9 @@ def flashattn_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -688,6 +704,9 @@ def llama_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions

View File

@@ -1,142 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -5,38 +5,11 @@ from typing import Optional
import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(

View File

@@ -0,0 +1,26 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
import transformers.modeling_attn_mask_utils
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)

View File

@@ -94,7 +94,7 @@ def _prepare_decoder_attention_mask(
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
if attention_mask is None:
if attention_mask is None or sliding_window is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
@@ -151,7 +151,7 @@ def flashattn_forward(
)
use_sliding_windows = (
hasattr(self.config, "sliding_window") is not None
getattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)

View File

@@ -2,9 +2,6 @@
Patches to support multipack for mixtral
"""
import torch
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def patch_mixtral_moe_forward_zero3() -> None:
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if for_zero3:
patch_mixtral_moe_forward_zero3()

View File

@@ -0,0 +1,61 @@
"""multipack patching for v2 of sample packing"""
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"falcon",
"phi",
"gemma",
"gemmoe",
"starcoder2",
]
def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_gemmoe to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_gemmoe", ".modeling_gemmoe"
)
modeling_gemmoe = importlib.import_module(module_name)
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for phi2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_phi_attn_with_multipack_flash_attn():
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for qwen2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_qwen2_attn_with_multipack_flash_attn():
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -4,14 +4,16 @@ import json
import logging
import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Sequence
from typing import Dict, List, Sequence, Union
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
@@ -23,23 +25,51 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.distributed import barrier, is_main_process
LOG = logging.getLogger("axolotl.relora")
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
tensor_magnitude = torch.abs(tensor)
threshold = torch.quantile(
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
).to(dtype=tensor.dtype)
if key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
mask = tensor_magnitude > threshold
tensor.mul_(mask.to(dtype=tensor.dtype))
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
n_zeros = 0
n_total = 0
optimizer_state = optimizer.state
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state
for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
class ReLoRACallback(TrainerCallback):
@@ -97,6 +127,25 @@ class ReLoRACallback(TrainerCallback):
"relora",
)
if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
lora_params = [
n
for n, p in model.named_parameters()
if p.requires_grad and "lora_" in n
]
model.save_pretrained(
os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
model,
@@ -107,7 +156,12 @@ class ReLoRACallback(TrainerCallback):
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
reset_optimizer(optimizer)
reset_optimizer(
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
)
if self.quantized:
self.last_full_model = checkpoint_folder
@@ -197,11 +251,13 @@ class ReLoRAScheduler(LRScheduler):
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
@@ -210,10 +266,20 @@ class ReLoRAScheduler(LRScheduler):
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
if step < self.relora_steps - self.warmup_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
per_relora_progress = step % self.relora_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.relora_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
@@ -238,7 +304,11 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter = layer.active_adapter
adapter: Union[List[str], str] = layer.active_adapter
if isinstance(adapter, list):
if len(adapter) > 1:
raise ValueError("unhandled relora for multiple adapters")
adapter = adapter[0]
return (
peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device)
@@ -248,7 +318,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
* layer.scaling[adapter]
)
return layer.get_delta_weight().to(device)
raise ValueError("unhandled lora layer type")
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -273,9 +343,9 @@ def update_weights(
):
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)
if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs
@@ -286,7 +356,9 @@ def update_weights(
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
target.weight.data = (
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
)
else:
target.weight.data = new_weight.to(device)
@@ -304,14 +376,17 @@ def merge_and_save(
if not quantized:
for module_name, target in modules.items():
update = target.get_delta_weight(target.active_adapter).detach()
active_adapter = target.active_adapter
if isinstance(active_adapter, list):
active_adapter = active_adapter[0]
update = target.get_delta_weight(active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)
return
os.makedirs(model_dst, exist_ok=True)
@@ -363,6 +438,7 @@ def merge_and_save(
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
barrier()
del in_tensors
del out_tensors
torch.cuda.empty_cache()

View File

@@ -1,8 +1,15 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
@@ -89,7 +96,6 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
@torch.jit.script
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -135,7 +141,18 @@ def get_cu_seqlens_from_pos_ids(position_ids):
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
# Find the maximum value across all tensors
max_value = max(t.max() for t in results)
# Find the length of the longest tensor
max_length = max(t.size(0) for t in results)
# Pad each tensor to the same length and collect them in a list
padded_results = [
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
]
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
@@ -149,3 +166,62 @@ def set_module_name(model, name, value):
child_name = name
setattr(parent, child_name, value)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)

View File

@@ -0,0 +1,20 @@
"""
module for base dataset transform strategies
"""
import importlib
import logging
LOG = logging.getLogger("axolotl")
def load(strategy, cfg, module_base=None, **kwargs):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", module_base)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None

View File

@@ -0,0 +1,78 @@
"""
HF Chat Templates prompt strategy
"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates"""
def __init__(self, tokenizer, chat_template=None, max_length=2048):
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False):
return self.tokenizer.apply_chat_template(
conversation,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy

Some files were not shown because too many files have changed in this diff Show More