Compare commits

...

82 Commits

Author SHA1 Message Date
Wing Lian
e9c3a2aec0 add missing dunder-init for monkeypatches and add tests for install from sdist (#2085)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 121, 12.1.1, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
* add missing dunder-init for monkeypatches and add tests for install from sdist

* fix gha name

* reduce matrix for sdist test
2024-11-19 12:43:30 -05:00
Wing Lian
02ca3f93b0 set manifest and fix for source dist (#2084)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 121, 12.1.1, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2024-11-19 11:31:56 -05:00
Wing Lian
5f6f9186e4 make sure action has permission to create release (#2083) [skip ci] 2024-11-19 10:43:02 -05:00
Wing Lian
6679e20f47 release version 0.5.1 (#2082) 2024-11-19 10:35:59 -05:00
Wing Lian
ec59d4cb83 remove deprecated extra metadata kwarg from pydantic Field (#2081) [skip ci] 2024-11-19 10:30:10 -05:00
Wing Lian
a77c8a71cf fix brackets on docker ci builds, add option to skip e2e builds [skip e2e] (#2080) [skip ci] 2024-11-19 10:29:31 -05:00
Wing Lian
775311f98f add optimizer step to prevent warning in tests (#1502) [skip ci]
* add optimizer step to prevent warning in tests

* add optimizer step to warmup as well
2024-11-19 10:19:03 -05:00
NanoCode012
f007c38e49 Feat: Drop long samples and shuffle rl samples (#2040) [skip ci]
* feat: LOG warn if samples are dropped due to seq length

* feat: add drop long samples for RL

* feat: add ipo

* fix: remove num_proc for map as subprocesses are prone to die

* feat: shuffle rl dataset

* fix: support preprocess for kto

* chore: use set instead of list

* feat: add simpo
2024-11-19 10:18:24 -05:00
Wing Lian
d9b71edf84 bump transformers for fsdp-grad-accum fix, remove patch (#2079) 2024-11-19 02:23:09 -05:00
Wing Lian
c07bd2fa65 Readme updates v2 (#2078)
* update readme logos

* use full logo

* Fix svgs

* add srcset

* resize svgs to match

* Rename file

* align badges center
2024-11-18 14:58:03 -05:00
Wing Lian
ed079d434a static assets, readme, and badges update v1 (#2077) 2024-11-18 13:59:32 -05:00
Wing Lian
8403c67156 don't build bdist (#2076) [skip ci] 2024-11-18 12:36:03 -05:00
Wing Lian
9871fa060b optim e2e tests to run a bit faster (#2069) [skip ci]
* optim e2e tests to run a bit faster

* run prequant w/o lora_modules_to_save

* use smollm2
2024-11-18 12:35:31 -05:00
Wing Lian
70cf79ef52 upgrade autoawq==0.2.7.post2 for transformers fix (#2070)
* point to upstream autoawq for transformers fix

* use autoawq 0.2.7 release

* test wheel for awq

* try different format for wheel def

* autoawq re-release

* Add intel_extension_for_pytorch dep

* ipex gte version

* forcefully remove intel-extension-for-pytorch

* add -y option to pip uninstall for ipex

* use post2 release for autoawq and remove uninstall of ipex
2024-11-18 11:53:37 -05:00
Wing Lian
c06b8f0243 increase worker count to 8 for basic pytests (#2075) [skip ci] 2024-11-18 11:52:35 -05:00
Chirag Jain
0c8b1d824a Update get_unpad_data patching for multipack (#2013)
* Update `get_unpad_data` patching for multipack

* Update src/axolotl/utils/models.py

* Update src/axolotl/utils/models.py

* Add test case

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-11-15 20:35:50 -05:00
NanoCode012
fd70eec577 fix: loading locally downloaded dataset (#2056) [skip ci] 2024-11-15 20:35:26 -05:00
Wing Lian
d42f202046 Fsdp grad accum monkeypatch (#2064) 2024-11-15 19:11:04 -05:00
Wing Lian
0dabde1962 support for schedule free and e2e ci smoke test (#2066) [skip ci]
* support for schedule free and e2e ci smoke test

* set default lr scheduler to constant in test

* ignore duplicate code

* fix quotes for config/dict
2024-11-15 19:10:14 -05:00
Wing Lian
15f1462ccd support passing trust_remote_code to dataset loading (#2050) [skip ci]
* support passing trust_remote_code to dataset loading

* add doc for trust_remote_code in dataset config
2024-11-15 19:09:48 -05:00
Wing Lian
521e62daf1 remove the bos token from dpo outputs (#1733) [skip ci]
* remove the bos token from dpo outputs

* don't forget to fix prompt_input_ids too

* use processing_class instead of tokenizer

* fix for processing class
2024-11-15 19:09:20 -05:00
Wing Lian
c16ec398d7 update to be deprecated evaluation_strategy (#1682) [skip ci]
* update to be deprecated evaluation_strategy and c4 dataset

* chore: lint

* remap eval strategy to new config and add tests
2024-11-15 19:09:00 -05:00
Wing Lian
2f20cb7ebf upgrade datasets==3.1.0 and add upstream check (#2067) [skip ci] 2024-11-15 19:08:38 -05:00
Wing Lian
71d4030b79 gradient accumulation tests, embeddings w pad_token fix, smaller models (#2059)
* add more test cases for gradient accumulation and fix zero3

* swap out for smaller model

* fix missing return

* fix missing pad_token in config

* support concurrency for multigpu testing

* cast empty deepspeed to empty string for zero3 check

* fix temp_dir as fixture so parametrize works properly

* fix test file for multigpu evals

* don't use default

* don't use default for fsdp_state_dict_type

* don't use llama tokenizer w smollm

* also automatically cancel multigpu for concurrency
2024-11-14 12:59:00 -05:00
Wing Lian
f3a5d119af fix env var extraction (#2043) [skip ci] 2024-11-14 12:58:06 -05:00
Wing Lian
ba219b51a5 fix duplicate base build (#2061) [skip ci] 2024-11-14 10:31:19 -05:00
Wing Lian
5be8e13d35 make sure to add tags for versioned tag on cloud docker images (#2060) 2024-11-14 10:24:49 -05:00
Wing Lian
2d7830fda6 upgrade to flash-attn 2.7.0 (#2048) 2024-11-14 06:59:25 -05:00
Wing Lian
5e98cdddac Grokfast support (#1917) 2024-11-13 17:10:36 -05:00
Sunny Liu
1d7aee0ad2 ADOPT optimizer integration (#2032) [skip ci]
* adopt integration

* stuff

* doc and test for ADOPT

* rearrangement

* fixed formatting

* hacking pre-commit

* chore: lint

* update module doc for adopt optimizer

* remove un-necessary example yaml for adopt optimizer

* skip test adopt if torch<2.5.1

* formatting

* use version.parse

* specifies required torch version for adopt_adamw

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-11-13 17:10:17 -05:00
Wing Lian
659ee5d723 don't cancel the tests on main automatically for concurrency (#2055) [skip ci] 2024-11-13 17:07:41 -05:00
Sunny Liu
342935cff3 Update unsloth for torch.cuda.amp deprecation (#2042)
* update deprecated unsloth tirch cuda amp  decorator

* WIP fix torch.cuda.amp deprecation

* lint

* laxing torch version requirement

* remove use of partial

* remove use of partial

* lint

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
2024-11-13 15:17:34 -05:00
Wing Lian
c5eb9ea2c2 fix push to main and tag semver build for docker ci (#2054) 2024-11-13 14:04:28 -05:00
Wing Lian
f2145a3ccb add default torch version if not installed, and support for xformers new wheels (#2049) 2024-11-13 13:16:47 -05:00
Wing Lian
010d0e7ff3 retry flaky test_packing_stream_dataset test that timesout on read (#2052) [skip ci] 2024-11-13 13:16:16 -05:00
Wing Lian
01881c3113 make sure to tag images in docker for tagged releases (#2051) [skip ci]
* make sure to tag images in docker for tagged releases

* fix tag event
2024-11-13 13:15:49 -05:00
Wing Lian
0e8eb96e07 run pypi release action on tag create w version (#2047) 2024-11-13 10:21:48 -05:00
NanoCode012
4e1891b12b feat: upgrade to liger 0.4.1 (#2045) 2024-11-13 10:07:24 -05:00
NanoCode012
28924fc791 feat: cancel ongoing tests if new CI is triggered (#2046) [skip ci] 2024-11-13 10:06:59 -05:00
NanoCode012
8c480b2804 fix: inference not using chat_template (#2019) [skip ci] 2024-11-13 10:06:41 -05:00
Oliver Molenschot
a4b1cc6df0 Add example YAML file for training Mistral using DPO (#2029) [skip ci]
* Add example YAML file for training Mistral using DPO

* chore: lint

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update mistral-dpo.yml 

Adding qlora and removing role-related data (unecessary)

* Rename mistral-dpo.yml to mistral-dpo-qlora.yml

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2024-11-13 10:06:25 -05:00
NanoCode012
7b78a31593 feat: print out dataset length even if not preprocess (#2034) [skip ci] 2024-11-13 10:06:00 -05:00
Wing Lian
810ebc2c0e invert the string in string check for p2p device check (#2044) 2024-11-12 23:20:47 -05:00
Wing Lian
ad435a3b09 add P2P env when multi-gpu but not the full node (#2041)
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-11-12 17:58:26 -05:00
NanoCode012
9f1cf9b17c fix: handle sharegpt dataset missing (#2035)
* fix: handle sharegpt dataset missing

* fix: explanation

* feat: add test
2024-11-12 12:51:37 +07:00
Wing Lian
3931a42763 change deprecated modal Stub to App (#2038) 2024-11-11 15:10:34 -05:00
NanoCode012
dc8f9059f7 feat: add metharme chat_template (#2033) [skip ci]
* feat: add metharme chat_template

* fix: add eos token
2024-11-11 15:09:58 -05:00
Wing Lian
234e94e9dd replace references to personal docker hub to org docker hub (#2036) [skip ci] 2024-11-11 15:09:29 -05:00
Wing Lian
f68fb71005 update actions version for node16 deprecation (#2037) [skip ci]
* update actions version for node16 deprecation

* update pre-commit/action to use 3.0.1 for actions/cache@v4 dep

* update docker/setup-buildx-action too to v3
2024-11-11 15:09:11 -05:00
Wing Lian
9bc3ee6c75 add axolotlai docker hub org to publish list (#2031)
* add axolotlai docker hub org to publish list

* fix to use latest actions docker metadata version

* fix list in yaml for expected format for action

* missed a change
2024-11-11 09:48:19 -05:00
Wing Lian
d356740ffa move deprecated kwargs from trainer to trainingargs (#2028) 2024-11-10 12:45:47 -05:00
Wing Lian
e4af51eb66 remove direct dependency on fused dense lib (#2027)
Some checks failed
publish pypi / Upload release to PyPI (push) Has been cancelled
2024-11-08 14:48:04 -05:00
Wing Lian
e20b15bee3 make publish to pypi manually dispatchable as a workflow (#2026) [skip ci] 2024-11-08 14:18:16 -05:00
Wing Lian
d4796cb645 increment version to 0.5.0 for next release (#2025) [skip ci] 2024-11-08 14:02:25 -05:00
Wing Lian
fd3b80716a remove fastchat and sharegpt (#2021)
* remove fastchat and sharegpt

* remove imports

* remove more fastchat imports

* chore: remove unused functions

* feat: remove sharegpt and deprecate from docs

* chore: remove unused sharegpt checks

* fix: remove sharegpt type from tests

* feat: add sharegpt deprecation error

* feat: update readme

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-08 13:45:49 -05:00
Sunny Liu
3265b7095e Add weighted optimisation support for trl DPO trainer integration (#2016)
* trlv0.12.0  integration

* update trl version requirements

* linting

* commenting out

* trl version requirement
2024-11-08 11:29:11 -05:00
Wing Lian
3cb2d75de1 upgrade pytorch to 2.5.1 (#2024) 2024-11-08 10:46:24 -05:00
Wing Lian
035e9f9dd7 janky workaround to install FA2 on torch 2.5.1 base image since it takes forever to build (#2022) 2024-11-07 17:54:29 -05:00
Wing Lian
02ce520b7e upgrade liger to 0.4.0 (#1973)
* upgrade liger to 0.3.1

* update docs and example

* skip duplicate code check

* Update src/axolotl/integrations/liger/args.py

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update README.md

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* add logging

* chore: lint

* add test case

* upgrade liger and transformers

* also upgrade accelerate

* use kwargs to support patch release

* make sure prepared path is empty for test

* use transfromers 4.46.1 since 4.46.2 breaks fsdp

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-07 12:53:34 -05:00
Wing Lian
052a9a79b4 only run the remainder of the gpu test suite if one case passes first (#2009) [skip ci]
* only run the remainder of the gpu test suite if one case passes first

* also reduce the test matrix
2024-10-31 13:45:01 -04:00
Wing Lian
3591bcfaf9 add torch 2.5.1 for base image (#2010) 2024-10-31 13:27:49 -04:00
Wing Lian
dc1de7d81b add retries for load datasets requests failures (#2007) 2024-10-31 13:26:14 -04:00
Chirag Jain
d4dbfa02fe Add plugin manager's callback hooks to training flow (#2006)
* Add plugin manager's callback hooks to training flow

* Use .values() instead of .items()
2024-10-31 12:13:46 -04:00
NanoCode012
5c7e89105d Fix: modelloader handling of model_kwargs load_in*bit (#1999)
* fix: load_in_*bit not properly read

* fix: load_*bit check

* fix: typo

* refactor: load * bit handling

* feat: add test dpo lora multi-gpu

* fix: turn off sample packing for dpo

* fix: missing warmup_steps

* fix: test to load in 8bit for lora

* skip 8bit lora on h100, add 4bit lora on h100 to multi gpu tests

* chore: reduce max_steps

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-30 14:41:34 -04:00
Chirag Jain
74db2a1bae Fix get_chat_template call for trainer builder (#2003) 2024-10-30 14:27:00 -04:00
Geun, Lim
e62554c419 feat: add Exaone3 chat_template (#1995) 2024-10-30 12:30:12 -04:00
Wing Lian
32c60765ef remove skipped test (#2002)
* remove skipped test

* use mean_resizing_embeddings with qlora and added tokens

* use </s> as pad_token to prevent resize of embeddings

* make sure local hub test saves to a tmp dir

* use Path so concatenation works

* make sure to use tmp_ds_path for data files
2024-10-30 12:27:04 -04:00
NanoCode012
8c3a727f9d feat: update yml chat_template to specify dataset field (#2001) [skip ci]
* feat: update yml chat_template to specify dataset field

* feat: replace sharegpt references with chat_template
2024-10-29 10:26:03 -04:00
Oliver Kunc
107b67b852 Hardware requirements (#1997) [skip ci]
* Hardware requirements

https://github.com/axolotl-ai-cloud/axolotl/issues/1992

* Update README.md

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-29 10:13:50 -04:00
NanoCode012
bfc77b0f36 Feat: Add support for tokenizer’s or custom jinja chat_template (#1970)
* Allow using tokenizer's default chat template with fallbacks

Summary of changes:

1. Adds `tokenizer_default` as option for `chat_template` in
   `chat_template` prompt strategy that allows using the chat template
   from tokenizer's config.json
2. Allows falling back to chat templates available in axolotl if
   tokenizer does not have a chat template
3. Adds a mistral chat template which supports system message - taken
   from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja

---

Why?

Many popular models are not trained with chatml format. As a result for
the model to correctly learn chatml we have to turn on train_on_inputs
which requires more compute and time. If we can use the model's already
learned chat template we can just learn the output tokens

---

Todo:

- Write tests

* Add tests

* Fix lint and bug post merge from main

* Add option `chat_template_jinja` to provide a jinja template

* remove custom mistral template

* Address review comments and add docs

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* fix: set default to tokenizer template

* Merge branch 'main' into cj_tokenizer_default_prompt_template

* chore: remove redundant function

* fix: re-arrange enum declaration position

* fix: refactor artifact left from main merge

* feat(doc): updated config with chat template options and clarified examples

* chore: clarify doc

* chore: added example for non-default template

* chore: refactor

* fix: test

* fix: config being dropped and unittest to catch that

* chore: lint

* chore: skip duplicate

* fix: rename var after merge

* feat: add test for levy's dpo case

* fix: remove default setting on edge case where chat template overriden in dataset section

* feat: handle sharegpt deprecation better in docs

* feat: add example using fallback

* feat: handles chat_template requiring specific user/assistant order

* fix: update test based on new defaults

* fix: imported name incorrectly updated on merge

* chore: lint

* fix: update dummy message to prevent potential overlap with real content

* fix(doc): formatting

* fix: update bradleyterry to use new chat_template

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
2024-10-29 10:14:51 +07:00
Wing Lian
e1e0556c99 add option for resizing embeddings when adding new tokens (#2000)
* add option for resizing embeddings when adding new tokens

* let's just be opinonated about this setting and set it to False
2024-10-28 17:02:04 -04:00
Wing Lian
d3c45d27b5 fix zero3 (#1994) 2024-10-28 07:32:49 -04:00
NanoCode012
2501c1a6a3 Fix: Gradient Accumulation issue (#1980)
* feat: support new arg num_items_in_batch

* use kwargs to manage extra unknown kwargs for now

* upgrade against upstream transformers main

* make sure trl is on latest too

* fix for upgraded trl

* fix: handle trl and transformer signature change

* feat: update trl to handle transformer signature

* RewardDataCollatorWithPadding no longer has max_length

* handle updated signature for tokenizer vs processor class

* invert logic for tokenizer vs processor class

* processing_class, not processor class

* also handle processing class in dpo

* handle model name w model card creation

* upgrade transformers and add a loss check test

* fix install of tbparse requirements

* make sure to add tbparse to req

* feat: revert kwarg to positional kwarg to be explicit

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-25 11:28:23 -04:00
Mengqing Cao
1d6a5e2bd6 Refactor func load_model to class ModelLoader (#1909) 2024-10-25 09:06:56 -04:00
Wing Lian
718cfb2dd1 revert image tagged as main-latest (#1990) 2024-10-22 13:54:24 -04:00
Adam Hazell
9bd5f7d015 Log checkpoints as mlflow artifacts (#1976)
* Ensure hf_mlflow_log_artifact config var is set in env

* Add transformer MLflowCallback to callbacks list when mlflow enabled

* Test hf_mlflow_log_artifacts is set correctly

* Test mlflow not being used by default
2024-10-22 08:52:21 -04:00
Wing Lian
5c629ee444 use torch 2.4.1 images as latest now that torch 2.5.0 is out (#1987) 2024-10-21 19:51:06 -04:00
Wing Lian
955cca41fc don't explicitly set cpu pytorch version (#1986)
use a constraint file
use min version of xformers
don't install autoawq with pytorch 2.5.0
debugging for errors
upgrade pip first
fix action yml
add back try/except
retry w/o constraint
use --no-build-isolation
show torch version
install setuptools and wheel
add back try/except
2024-10-21 19:50:50 -04:00
Wing Lian
e12a2130e9 first pass at pytorch 2.5.0 support (#1982)
* first pass at pytorch 2.5.0 support

* attempt to install causal_conv1d with mamba

* gracefully handle missing xformers

* fix import

* fix incorrect version, add 2.5.0

* increase tests timeout
2024-10-21 11:00:45 -04:00
Wing Lian
67f744dc8c add pytorch 2.5.0 base images (#1979)
* add pytorch 2.5.0 base images

* make sure num examples for debug is zero and fix comparison
2024-10-18 03:36:51 -04:00
Sunny Liu
f62e23737b memoize dataset length for eval sample packing (#1974)
* wip on multimodal sample packing support

* wip on multimodal packing support

* llama-1b-yml

* setup logging for test

* yml

* yml

* yml

* fix for __len__ for eval sample packing

* reverted irrelavant changes

* reformatted, reverted log message

* reverted unnecessary changes

* added e2e multigpu testing for eval sample packing

* formatting

* fixed e2e test_eval params

* fix test_eval e2e multigpu

* fix test_eval e2e multigpu

* Update tests/e2e/multigpu/test_eval.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update tests/e2e/multigpu/test_eval.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-17 15:15:29 -04:00
Wing Lian
54673fd6ca also debug if other debug args are set (#1977) 2024-10-17 14:12:31 -04:00
126 changed files with 5074 additions and 3649 deletions

View File

@@ -27,7 +27,7 @@ jobs:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
python_version: "3.10"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
@@ -36,21 +36,29 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v3
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-base
images: |
winglian/axolotl-base
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
with:

View File

@@ -17,7 +17,7 @@ jobs:
- name: Set up Quarto
uses: quarto-dev/quarto-actions/setup@v2
- name: Setup Python
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: install dependencies

View File

@@ -15,9 +15,9 @@ jobs:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
- uses: pre-commit/action@v3.0.1

View File

@@ -4,11 +4,13 @@ on:
push:
branches:
- "main"
tags:
- "v*"
workflow_dispatch:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -29,6 +31,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -37,7 +44,12 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
@@ -51,7 +63,7 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
@@ -65,7 +77,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -86,6 +98,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -94,20 +111,25 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud
push: ${{ github.event_name != 'pull_request' }}
@@ -118,7 +140,7 @@ jobs:
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -136,20 +158,25 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud-term
images: |
winglian/axolotl-cloud-term
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-no-tmux
push: ${{ github.event_name != 'pull_request' }}

View File

@@ -8,9 +8,14 @@ on:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -21,10 +26,17 @@ jobs:
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 121
cuda_version: 12.1.1
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.3.1
pytorch: 2.4.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"

View File

@@ -7,7 +7,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -28,6 +28,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -36,7 +41,9 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
- name: Set up Docker Buildx
@@ -64,7 +71,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -85,6 +92,11 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -93,7 +105,9 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
- name: Login to Docker Hub
@@ -102,7 +116,7 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:

View File

@@ -3,12 +3,33 @@ name: publish pypi
on:
push:
tags:
- '*'
- 'v*'
workflow_dispatch:
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup_release]
environment:
name: pypi
url: https://pypi.org/p/axolotl
@@ -16,10 +37,10 @@ jobs:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.10"
@@ -27,7 +48,7 @@ jobs:
run: |
pip3 install wheel packaging
pip3 install -e .
pip3 install -r requirements-tests.txt
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name
id: tag
@@ -37,9 +58,9 @@ jobs:
run: |
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a binary wheel
- name: Build a source dist
run: |
python setup.py sdist bdist_wheel
python setup.py sdist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -9,12 +9,12 @@ jobs:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -25,15 +25,15 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
@@ -47,13 +47,15 @@ jobs:
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-tests.txt
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
@@ -81,13 +83,6 @@ jobs:
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -95,6 +90,13 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -15,17 +15,22 @@ on:
- '.github/workflows/*.yml'
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -36,61 +41,97 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 show torch
pip3 install -U -e .
pip3 install -r requirements-tests.txt
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
pytest --ignore=tests/e2e/ tests/
pytest -n8 --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
needs: [pre-commit, pytest]
timeout-minutes: 90
needs: [pre-commit, pytest, pytest-sdist]
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -119,3 +160,49 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests

4
MANIFEST.in Normal file
View File

@@ -0,0 +1,4 @@
include requirements.txt
include README.md
include LICENSE
recursive-include axolotl *.py

View File

@@ -1,8 +1,21 @@
# Axolotl
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg)
![tests-nightly](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg)
![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg)
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
</p>
<p align="center">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
@@ -75,7 +88,7 @@ Features:
<td>
<div align="center">
<img src="image/axolotl.png" alt="axolotl" width="160">
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
@@ -121,7 +134,7 @@ Features:
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -159,7 +172,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
#### Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Or run on the current files for development:
@@ -178,7 +191,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
```
It additionally:
@@ -210,7 +223,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
#### Cloud GPU
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
@@ -319,7 +332,7 @@ Write a job description in YAML as below:
# dstack.yaml
type: task
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
image: axolotlai/axolotl-cloud:main-latest
env:
- HUGGING_FACE_HUB_TOKEN
@@ -383,11 +396,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation (deprecation soon, use chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
- path: ...
type: sharegpt
conversation: chatml # default: vicuna_v1.1
type: chat_template
chat_template: chatml # defaults to tokenizer's chat_template
# local
- path: data.jsonl # or json
@@ -562,7 +574,8 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```

View File

@@ -1,4 +1,4 @@
FROM winglian/axolotl-base:{{ BASE_TAG }}
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
@@ -23,11 +23,12 @@ RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -37,7 +38,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi
# So we can test the Docker image
RUN pip install -r requirements-tests.txt
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,6 +1,6 @@
#!/bin/bash
set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -10,7 +10,7 @@ import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -46,7 +46,7 @@ cicd_image = (
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
app = App("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 2))
@@ -61,10 +61,10 @@ def run_cmd(cmd: str, run_folder: str):
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)
@@ -72,6 +72,6 @@ def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
@stub.local_entrypoint()
@app.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -2,4 +2,4 @@
set -e
# only run one test at a time so as not to OOM the GPU
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/

View File

@@ -10,7 +10,7 @@ import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -47,7 +47,7 @@ cicd_image = (
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
app = App("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
@@ -62,10 +62,10 @@ def run_cmd(cmd: str, run_folder: str):
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
timeout=60 * 60,
cpu=8.0,
memory=131072,
)
@@ -73,6 +73,6 @@ def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@stub.local_entrypoint()
@app.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -14,15 +14,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -24,15 +24,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,15 +20,6 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -1,4 +1,4 @@
# Example config for debugging the sharegpt prompt format
# Example config for debugging the chat_template prompt format
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
@@ -7,8 +7,8 @@ load_in_8bit: true
load_in_4bit: false
datasets:
- path: philschmid/guanaco-sharegpt-style
type: sharegpt
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
shards: 10
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main-base
FROM winglian/axolotl-base:$BASE_TAG
FROM axolotlai/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
@@ -20,7 +20,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main-base
FROM winglian/axolotl-base:$BASE_TAG
FROM axolotlai/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""

View File

@@ -83,7 +83,7 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
@@ -91,15 +91,7 @@ datasets:
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
# Custom user instruction prompt
- path: repo
@@ -124,6 +116,48 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:
# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
chat_template_jinja:
# The key in the data example that contains the messages. Default is "messages".
field_messages: messages
# The key in the message turn that contains the role. Default is "role".
message_field_role: role
# The key in the message turn that contains the content. Default is "content".
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping for the messages.
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant", "ai"]
system: ["system"]
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["gpt", "assistant"]
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn: train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
# See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
@@ -141,10 +175,19 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
@@ -363,6 +406,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd

View File

@@ -6,31 +6,8 @@ order: 3
## sharegpt
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
```{.json filename="data.jsonl"}
{"conversations": [{"from": "...", "value": "..."}]}
```
Note: `type: sharegpt` opens special configs:
- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options.
- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking.
- `field_human`: specify the key to use instead of `human` in the conversation.
- `field_model`: specify the key to use instead of `gpt` in the conversation.
```yaml
datasets:
path: ...
type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].
```
## pygmalion
@@ -38,34 +15,137 @@ datasets:
{"conversations": [{"role": "...", "value": "..."}]}
```
## sharegpt.load_role
conversations where `role` is used instead of `from`
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
{"conversations": [{"role": "...", "content": "..."}]}
```
## sharegpt.load_guanaco
See `config.qmd` for full configs and supported templates.
conversations where `from` is `prompter` `assistant` instead of default sharegpt
### Migrating from sharegpt
Most configs can be adapted as follows:
```yaml
# old
chat_template: chatml
datasets:
- path: ...
type: sharegpt
conversation: chatml
# new (if using tokenizer's chat_template)
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
```
We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
- path: ...
type: chat_template
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
```{.json filename="data.jsonl"}
{"conversations": [{"from": "...", "value": "..."}]}
{
"conversations": [
{"from": "system", "value": "You are an AI assistant.", "train": false},
{"from": "human", "value": "Hello", "train": false},
{"from": "assistant", "value": "Hello", "train": true},
{"from": "human", "value": "How are you?", "train": true},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": false},
{"begin_offset": 9, "end_offset": 18, "train": true},
{"begin_offset": 19, "end_offset": 30, "train": false},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": true,
},
{"from": "assistant", "value": "Hi there!", "train": true}
]
}
```
## sharegpt.load_ultrachat
The configuration would look like:
conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
```{.json filename="data.jsonl"}
{"messages": [{"user": "...", "assistant": "..."}]}
```yaml
datasets:
- path: ...
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
message_field_training_detail: train_detail
```
## sharegpt_jokes
creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.

View File

@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
### Background
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
```yaml
datasets:
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
type: sharegpt
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
type: chat_template
```
>[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
// .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"version": "0.2.0",
"configurations": [
{
"name": "Debug axolotl prompt - sharegpt",
"name": "Debug axolotl prompt - chat_template",
"type": "python",
"module": "accelerate.commands.launch",
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
## Debugging With Docker
Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
### Setup
@@ -202,11 +202,11 @@ cd axolotl
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
```
>[!Tip]
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
You will now be in the container. Next, perform an editable install of Axolotl:
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
</div>
<br>
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -44,7 +44,7 @@
"outputs": [],
"source": [
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install flash-attn==\"2.7.0.post2\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
]
},

View File

@@ -9,14 +9,17 @@ strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -11,8 +11,11 @@ chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -4,11 +4,15 @@ tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
chat_template: jamba
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft

View File

@@ -4,7 +4,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
strict: false
@@ -14,6 +14,10 @@ datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./outputs/out

View File

@@ -0,0 +1,77 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -0,0 +1,93 @@
#Note that we are switching from the regular chat template to chatml.
#If you experience problems with the special tokens, training for more epochs can help.
#After training, merge the model before inference otherwise you might
#face problems with the special tokens.
base_model: mistralai/Mistral-7B-Instruct-v0.2
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
chat_template: chatml
rl: dpo
datasets:
- path: olivermolenschot/alpaca_messages_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-qlora
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true
adapter: qlora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.2
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 16
num_epochs: 6
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"

View File

@@ -10,7 +10,6 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content

67
examples/qwen2/dpo.yaml Normal file
View File

@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-0.5B
strict: false
chat_template: qwen_25
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/dpo-out
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 24 KiB

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
<path fill="#141310" d="M435,234.3l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185.1h31.6l47.9,185.1h-24.5ZM417.7,164.9l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path fill="#141310" d="M568.2,234.3l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
<path fill="#141310" d="M658.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM658.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#141310" d="M860.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM860.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#141310" d="M773.9,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
<path fill="#141310" d="M1036.2,234.3V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.8v-24.1h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
<path fill="#141310" d="M978.6,234.3c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3v-45.3h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
<path fill="#141310" d="M51.5,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v32.8h20.6v-32.8c0-4.7,3.8-8.4,8.4-8.4Z"/>
<path fill="#141310" d="M92.8,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v12.2h20.6v-12.2c0-4.7,3.8-8.4,8.4-8.4Z"/>
<path fill="#141310" d="M249.3,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v32.8h20.6v-32.8Z"/>
<path fill="#141310" d="M187.4,90.2v-20.6h-103.1v20.6h-41.2v20.6h-20.6v41.2c0,11.4,9.2,20.6,20.6,20.6h185.5c11.4,0,20.6-9.2,20.6-20.6v-41.2h-20.6v-20.6h-41.2ZM166.8,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3ZM228.7,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3Z"/>
<path fill="#141310" d="M208,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v12.2h20.6v-12.2Z"/>
<rect fill="#141310" x="22.5" y="234.5" width="41.2" height="20.6"/>
<rect fill="#141310" x="84.3" y="234.5" width="164.9" height="20.6"/>
<rect fill="#141310" x="208" y="193.3" width="41.2" height="20.6"/>
<rect fill="#141310" x="22.5" y="193.3" width="164.9" height="20.6"/>
</svg>

After

Width:  |  Height:  |  Size: 3.2 KiB

View File

@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
<path fill="#fff" d="M462.9,234.2l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185h31.6l47.9,185h-24.4ZM445.7,164.8l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path fill="#fff" d="M596.1,234.2l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.5-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.3,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.1,49.3,71.6h-28.5Z"/>
<path fill="#fff" d="M686.4,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM686.4,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#fff" d="M888.3,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM888.3,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#fff" d="M801.7,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.5c0,4.7,3.8,8.5,8.5,8.5h16.7v24.1h-16.7Z"/>
<path fill="#fff" d="M1063.8,234.2V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.7v-24.1h16.7c18,0,32.6,14.6,32.6,32.6v152.8h-24.1Z"/>
<path fill="#fff" d="M1006.2,234.2c-18,0-32.6-14.6-32.6-32.6v-85h-20.3v-22.1h20.3v-45.2h24.1v45.2h30.2v22.1h-30.2v85c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
<path fill="#fff" d="M160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM277.3,57.4c0-23.8-19.3-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.7,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.7-6.3-14.1-14.1-14.1h-12.2c-6.5,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.3-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.2c0,11,5.2,20.8,13.2,27.2-7.3.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.7,6.3,14.1,14.1,14.1h41.2c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h164.9c7.7,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.8-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.2c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM77.8,255.1h-41.2v-20.6h41.2v20.6ZM36.5,213.9v-20.6h164.9v20.6H36.5ZM263.3,255.1H98.4v-20.6h164.9v20.6ZM263.3,213.9h-41.2v-20.6h41.2v20.6ZM263.3,90.2h-20.6v20.6h20.6v41.2c0,11.4-9.2,20.6-20.6,20.6H57.2c-11.4,0-20.6-9.2-20.6-20.6v-41.2h20.6v-20.6h-20.6v-32.8c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.2v-20.6h-20.6v-12.2c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.1v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v12.2h-20.6v20.6h41.2v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v32.8ZM201.4,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
</svg>

After

Width:  |  Height:  |  Size: 6.6 KiB

View File

@@ -0,0 +1,26 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
<defs>
<style>
.cls-1 {
fill: #141310;
}
</style>
</defs>
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<g>
<path class="cls-1" d="M46.9,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v36.9h23.2v-36.9c0-5.2,4.2-9.5,9.5-9.5Z"/>
<path class="cls-1" d="M93.2,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v13.7h23.2v-13.7c0-5.2,4.2-9.5,9.5-9.5Z"/>
<path class="cls-1" d="M269.3,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v36.9h23.2v-36.9Z"/>
<path class="cls-1" d="M199.7,83.8v-23.2h-116v23.2h-46.4v23.2H14.2v46.4c0,12.8,10.4,23.2,23.2,23.2h208.7c12.8,0,23.2-10.4,23.2-23.2v-46.4h-23.2v-23.2h-46.4ZM176.5,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6ZM246.1,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6Z"/>
<path class="cls-1" d="M222.9,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v13.7h23.2v-13.7Z"/>
<rect class="cls-1" x="14.2" y="246.1" width="46.4" height="23.2"/>
<rect class="cls-1" x="83.8" y="246.1" width="185.5" height="23.2"/>
<rect class="cls-1" x="222.9" y="199.7" width="46.4" height="23.2"/>
<rect class="cls-1" x="14.2" y="199.7" width="185.5" height="23.2"/>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
<defs>
<style>
.cls-1 {
fill: #fff;
}
</style>
</defs>
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<path class="cls-1" d="M152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM269.3,57.3c0-23.8-19.4-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.8,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.8-6.3-14.1-14.1-14.1h-12.2c-6.6,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.4-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.3c0,11,5.2,20.9,13.2,27.2-7.4.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.8,6.3,14.1,14.1,14.1h41.3c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h165.1c7.8,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.9-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.3c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM69.5,255.2H28.2v-20.6h41.3v20.6ZM28.2,214v-20.6h165.1v20.6H28.2ZM255.2,255.2H90.1v-20.6h165.1v20.6ZM255.2,214h-41.3v-20.6h41.3v20.6ZM255.2,90.1h-20.6v20.6h20.6v41.3c0,11.4-9.2,20.6-20.6,20.6H48.9c-11.4,0-20.6-9.2-20.6-20.6v-41.3h20.6v-20.6h-20.6v-32.8c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.3v-20.6h-20.6v-12.2c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.2v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v12.2h-20.6v20.6h41.3v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v32.8ZM193.3,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 5.0 KiB

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<g>
<path d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
<path d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
<path d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
<path d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
<defs>
<style>
.cls-1 {
fill: #fff;
}
</style>
</defs>
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<g>
<path class="cls-1" d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path class="cls-1" d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
<path class="cls-1" d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path class="cls-1" d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path class="cls-1" d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
<path class="cls-1" d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
<path class="cls-1" d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.3 KiB

View File

@@ -2,3 +2,4 @@ pre-commit
black
mypy
types-requests
tbparse

View File

@@ -1,2 +1,3 @@
pytest
pytest-xdist
pytest-retry

View File

@@ -1,22 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.45.2
transformers==4.46.3
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.14.4
accelerate==1.1.0
datasets==3.1.0
deepspeed==0.15.4
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
flash-attn==2.6.3
flash-attn==2.7.0.post2
sentencepiece
wandb
einops
xformers==0.0.28.post1
xformers>=0.0.23.post1
optimum==1.16.2
hf_transfer
colorama
@@ -28,13 +28,12 @@ scipy
scikit-learn==1.4.2
pynvml
art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
autoawq==0.2.7.post2
triton>=2.3.0
liger-kernel==0.3.0
liger-kernel==0.4.1
mamba-ssm==1.2.0.post1
@@ -43,7 +42,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl==0.9.6
trl==0.12.0
zstandard==0.22.0
fastcore
@@ -54,3 +53,4 @@ immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.5.0
schedulefree==1.3.0

View File

@@ -2,7 +2,7 @@
# Export specific ENV variables to /etc/rp_environment
echo "Exporting environment variables..."
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc
add_keys_to_authorized() {

View File

@@ -31,13 +31,18 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
torch_version = version("torch")
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -50,10 +55,20 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 4):
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
@@ -73,7 +88,6 @@ def parse_requirements():
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
@@ -82,26 +96,24 @@ install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.4.1",
version="0.5.1",
description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages(),
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.6.3",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
"flash-attn==2.7.0.post2",
],
"deepspeed": [
"deepspeed==0.14.4",
"deepspeed==0.15.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",

View File

@@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -190,18 +190,15 @@ def do_inference(
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -211,13 +208,31 @@ def do_inference(
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
@@ -257,13 +272,6 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
@@ -272,7 +280,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = chat_templates(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -462,7 +470,12 @@ def load_datasets(
processor=processor,
)
if cli_args.debug or cfg.debug:
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(

View File

@@ -23,10 +23,6 @@ from axolotl.cli import (
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -44,23 +40,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return_remaining_strings=True
)
if parsed_cfg.chat_template == "chatml":
if parsed_cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_chatml_template(parsed_cfg.default_system_message)
else:
register_chatml_template()
elif parsed_cfg.chat_template == "llama3":
if parsed_cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_llama3_template(parsed_cfg.default_system_message)
else:
register_llama3_template()
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED

View File

@@ -19,10 +19,6 @@ from axolotl.cli import (
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
@@ -42,21 +38,6 @@ def do_train(cfg, cli_args) -> None:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.chat_template == "chatml" and cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {cfg.default_system_message}"
)
register_chatml_template(cfg.default_system_message)
else:
register_chatml_template()
if cfg.chat_template == "llama3" and cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
)
register_llama3_template(cfg.default_system_message)
else:
register_llama3_template()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -23,7 +23,7 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)

View File

@@ -7,6 +7,7 @@ import abc
import gc
import importlib
import importlib.util
import inspect
import logging
import math
import os
@@ -27,7 +28,6 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -48,6 +48,7 @@ from trl import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -63,7 +64,7 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -435,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
@@ -504,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -666,7 +681,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
@@ -674,8 +691,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)
return self.orpo_compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -771,7 +798,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(self, model, inputs, return_outputs=False):
def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
@@ -877,13 +910,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -898,6 +931,7 @@ class AxolotlMambaTrainer(AxolotlTrainer):
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
@@ -1004,19 +1038,46 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs)
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
@@ -1114,17 +1175,28 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -1135,11 +1207,17 @@ class TrainerBuilderBase(abc.ABC):
return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1185,7 +1263,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1222,6 +1300,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def _get_trainer_cls(self):
@@ -1339,17 +1429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
training_arguments_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
elif self.cfg.eval_strategy:
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
training_arguments_kwargs["eval_strategy"] = "epoch"
if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
@@ -1556,8 +1644,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates(
self.cfg.chat_template
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
tokenizer=self.tokenizer,
)
if self.cfg.rl == "orpo":
@@ -1573,11 +1662,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1662,12 +1753,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
callbacks=self.get_callbacks(),
**trainer_kwargs,
@@ -1708,6 +1804,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
@@ -1745,7 +1843,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build_training_arguments(self, total_num_steps):
@@ -1773,10 +1871,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.eval_dataset:
training_args_kwargs["evaluation_strategy"] = "steps"
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["evaluation_strategy"] = "no"
training_args_kwargs["eval_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
@@ -1831,17 +1929,18 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
@@ -1850,13 +1949,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
if self.cfg.rl == "orpo":
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
if self.cfg.rl == "kto":
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
@@ -1871,6 +1970,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1891,7 +2001,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
@@ -1905,12 +2014,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
@@ -1922,11 +2025,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
@@ -1948,11 +2057,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
def get_callbacks(self):
callbacks = []
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):

View File

@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import List
from typing import OrderedDict
class BasePlugin:
@@ -47,7 +48,7 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg):
def register(self, cfg): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
@@ -63,7 +64,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg):
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
@@ -74,7 +75,7 @@ class BasePlugin:
None
"""
def post_model_load(self, cfg, model):
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -86,7 +87,7 @@ class BasePlugin:
None
"""
def pre_lora_load(self, cfg, model):
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
@@ -98,7 +99,7 @@ class BasePlugin:
None
"""
def post_lora_load(self, cfg, model):
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
@@ -110,7 +111,7 @@ class BasePlugin:
None
"""
def create_optimizer(self, cfg, trainer):
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -122,7 +123,9 @@ class BasePlugin:
object: The created optimizer.
"""
def create_lr_scheduler(self, cfg, trainer, optimizer):
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -135,9 +138,9 @@ class BasePlugin:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model):
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer before training.
setup callbacks before creating the trainer.
Parameters:
cfg (dict): The configuration for the plugin.
@@ -146,20 +149,25 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def add_callbacks_post_trainer(self, cfg, trainer):
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after training.
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
List[callable]: A list of callback functions to be added
"""
return []
def post_train(self, cfg, model):
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
@@ -171,7 +179,7 @@ class BasePlugin:
None
"""
def post_train_unload(self, cfg):
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
@@ -227,7 +235,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: List[BasePlugin] = []
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
@@ -237,7 +245,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: List[BasePlugin] = []
cls._instance.plugins = collections.OrderedDict()
return cls._instance
@staticmethod
@@ -265,7 +273,7 @@ class PluginManager:
"""
try:
plugin = load_plugin(plugin_name)
self.plugins.append(plugin)
self.plugins[plugin_name] = plugin
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -277,7 +285,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins:
for plugin in self.plugins.values():
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -293,7 +301,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
@@ -307,7 +315,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
@@ -321,7 +329,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
@@ -335,7 +343,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
@@ -349,7 +357,7 @@ class PluginManager:
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -367,7 +375,7 @@ class PluginManager:
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -385,8 +393,10 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
for plugin in self.plugins.values():
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
if plugin_callbacks: # if the plugin returned a list of callbacks
callbacks.extend(plugin_callbacks)
return callbacks
def add_callbacks_post_trainer(self, cfg, trainer):
@@ -401,8 +411,10 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
for plugin in self.plugins.values():
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
if plugin_callbacks:
callbacks.extend(plugin_callbacks)
return callbacks
def post_train_unload(self, cfg):
@@ -416,5 +428,5 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,13 @@
# Grokfast Optimizer
See https://github.com/ironjr/grokfast
### Usage
```yaml
plugins:
- axolotl.integrations.grokfast.GrokfastPlugin
grokfast_alpha: 2.0
grokfast_lamb: 0.98
```

View File

@@ -0,0 +1,50 @@
"""
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast")
class GrokfastCallbackHandler(TrainerCallback):
"""
Transformer trainer callbacks for Grokfast
"""
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
super().__init__(*args_, **kwargs)
self.grads = None
self.alpha = alpha
self.lamb = lamb
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
self.grads = None
def on_pre_optimizer_step(
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control
class GrokfastPlugin(BasePlugin):
"""
Plugin for Grokfast optimizer integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.grokfast.GrokfastArgs"
def add_callbacks_post_trainer(self, cfg, trainer):
LOG.info("Adding Grokfast callback to the trainer")
callback = GrokfastCallbackHandler(
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
)
return [callback]

View File

@@ -0,0 +1,15 @@
"""
config args for grokfast plugin
"""
from typing import Optional
from pydantic import BaseModel
class GrokfastArgs(BaseModel):
"""
Input args for Grokfast optimizer.
"""
grokfast_alpha: Optional[float] = 0.98
grokfast_lamb: Optional[float] = 2.0

View File

@@ -0,0 +1,63 @@
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
# Reference: https://github.com/ironjr/grokfast
# pylint: skip-file
from collections import deque
from typing import Dict, Literal, Optional
import torch
import torch.nn as nn
def gradfilter_ma(
m: nn.Module,
grads: Optional[Dict[str, deque]] = None,
window_size: int = 100,
lamb: float = 5.0,
filter_type: Literal["mean", "sum"] = "mean",
warmup: bool = True,
trigger: bool = False, # For ablation study.
) -> Dict[str, deque]:
if grads is None:
grads = {
n: deque(maxlen=window_size)
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n].append(p.grad.data.detach()) # .cpu())
# Modify the gradients.
if not warmup or len(grads[n]) == window_size and not trigger:
if filter_type == "mean":
avg = sum(grads[n]) / len(grads[n])
elif filter_type == "sum":
avg = sum(grads[n])
else:
raise ValueError(f"Unrecognized filter_type {filter_type}")
p.grad.data = p.grad.data + avg * lamb
return grads
def gradfilter_ema(
m: nn.Module,
grads: Optional[Dict[str, torch.Tensor]] = None,
alpha: float = 0.98,
lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
if grads is None:
grads = {
n: p.grad.data.detach()
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
p.grad.data = p.grad.data + grads[n] * lamb
return grads

View File

@@ -18,20 +18,24 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin):
"""
@@ -42,59 +46,31 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -104,30 +80,14 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
@@ -146,44 +106,11 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,9 +15,12 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
class LigerArgs(BaseModel):
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data

View File

View File

@@ -1,231 +0,0 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", "" # must be end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
if self.system_message:
if self.messages:
# For llama, the system message is incorporated into the first human instruction
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
if (i % 2 == 0 and not self.system_message) or (
i % 2 != 0 and self.system_message
):
role = "<s> " + role
yield role + " ", message
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
contains_sys_msg = False
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(
system_message=" " + self.system_message
)
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message and i == 0 and not contains_sys_msg:
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
elif message:
yield role + " ", message
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.LLAMA3:
if self.system_message:
# For llama3, the system message is NOT incorporated into the first human instruction
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
else:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
return
if self.sep_style == SeparatorStyle.GEMMA:
if self.system_message:
raise ValueError("Gemma chat template does not support system messages")
for i, (role, message) in enumerate(self.messages):
prefix = "<bos>" if i == 0 else ""
message_str = message if message else ""
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATGLM3:
if self.system_message:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + "\n", " " + message
else:
yield role
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt

View File

@@ -22,7 +22,6 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -44,7 +43,19 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_available() -> bool:
try:
import xformers # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
def is_xformers_swiglu_available() -> bool:
if not is_xformers_available():
return False
from xformers.ops.common import get_xformers_operator
try:
@@ -57,6 +68,11 @@ def is_xformers_swiglu_available() -> bool:
def replace_llama_mlp_with_swiglu(model):
if is_xformers_swiglu_available():
from axolotl.monkeypatch.xformers_ import FusedMLP
else:
raise RuntimeError("xformers SwiGLU not available for this environment")
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
@@ -181,49 +197,6 @@ class FusedAttention(LlamaAttention):
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(

View File

@@ -1,4 +1,5 @@
"""multipack patching for v2 of sample packing"""
import importlib
import transformers
@@ -27,71 +28,28 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return
# retain for legacy
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2_moe":
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
def patch_remote(model_name, config_name, modeling_name):
def patch_remote(model_name):
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_* to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
parts = model_config.__class__.__module__.split(".")
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
module_name = ".".join(parts)
modeling_arch = importlib.import_module(module_name)
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
if hasattr(modeling_arch, "_get_unpad_data"):
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -16,26 +16,6 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -80,12 +60,6 @@ def get_forward_code() -> str:
return forward
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -98,48 +72,31 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
logits,
labels,
vocab_size: int, # pylint: disable=unused-argument
num_items_in_batch: int = None,
ignore_index: int = -100, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama":
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
from transformers.loss import loss_utils
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
else:
raise ValueError("Unsupported model type")

View File

@@ -0,0 +1,51 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import set_module_name
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)

View File

@@ -6,7 +6,7 @@ import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,13 +2,18 @@
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -27,18 +32,24 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
return {
@@ -53,15 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail"
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
@@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = BTChatTemplateStrategy(

View File

@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -405,10 +405,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),

View File

@@ -2,15 +2,16 @@
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_str = chat_templates(cfg.chat_template)
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -30,6 +31,12 @@ def default(
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
messages = sample[field_messages]
messages = [
{
@@ -46,28 +53,29 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[chosen],
[dummy_user_message, chosen],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
[dummy_user_message, rejected],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])

View File

@@ -1,33 +0,0 @@
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = InstructShareGPTPromptTokenizingStrategy(
# pylint: disable=duplicate-code
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return [
{"from": "human", "value": prompt["instruction"]},
{"from": "gpt", "value": prompt["output"]},
]

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
@dataclass
@@ -75,7 +75,7 @@ class Llama2ChatConversation:
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for ShareGPT prompts.
Tokenizing strategy for Llama2 prompts.
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
"""
@@ -191,7 +191,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE
if sentence["value"]:
conv.append_message(role, sentence["value"])
yield conv

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
class Message(BaseModel):
@@ -28,18 +28,13 @@ def load(
"""
chatml transforms for datasets with system, input, chosen, rejected
"""
chat_template = chat_templates("chatml")
if ds_cfg and "chat_template" in ds_cfg:
chat_template = ds_cfg["chat_template"]
try:
chat_template = chat_templates(chat_template)
except ValueError:
pass
tokenizer.chat_template = chat_template
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
tokenizer.chat_template = chat_template_string
return ORPOTokenizingStrategy(
ORPOPrompter(chat_template, tokenizer),
ORPOPrompter(chat_template_string, tokenizer),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -248,28 +243,30 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, tokenizer=tokenizer
)
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]

View File

@@ -1,223 +0,0 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging
from typing import Any, Dict, Optional, Type
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.utils.tokenization import (
chatml_to_conversation,
merge_consecutive_messages,
)
LOG = logging.getLogger("axolotl")
def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
register_conv_template(
Conversation(
name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
def register_llama3_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
register_conv_template(
Conversation(
name="llama3",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
system_message=system_message,
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
sep="",
stop_str="<|eot_id|>",
stop_token_ids=[128001, 128009],
)
)
def build_loader(
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
prompter_cls: Type["ShareGPTPrompterV2"],
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
)
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else default_conversation
)
field_human = (
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
)
field_model = (
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
)
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = tokenization_strategy_cls(
prompter_cls(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
return _load
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = False
_messages = "conversations"
@property
def strict(self):
return self._strict
@strict.setter
def strict(self, strict):
self._strict = strict
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
def get_conversation_thread(self, prompt):
conversations = prompt[self.messages]
if self.strict:
return conversations
role_key = "from"
if "role" in conversations[0].keys():
role_key = "role"
value_key = "value"
if "text" in conversations[0].keys():
value_key = "text"
elif "content" in conversations[0].keys():
value_key = "content"
# remap roles - allow for assistant turn"
role_map = {
"user": "human",
"human": "human",
"assistant": "gpt",
"gpt": "gpt",
"system": "system",
}
turns = [
{
"from": (
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
"weight": 1
if "weight" not in t or t["weight"] is None
else t["weight"],
}
for t in conversations
]
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(
SimpleShareGPTPromptTokenizingStrategy
):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
]
return turns
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps ultrachat data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["messages"]
role_map = {"user": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps glaive data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversation = chatml_to_conversation(prompt)
conversation = merge_consecutive_messages(conversation)
return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -1,28 +0,0 @@
"""Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
Tokenization strategy for asking bot to tell a joke and then explain why its funny
"""
# title, text, explanation
def get_conversation_thread(self, prompt):
title = "" if not prompt["title"] else prompt["title"] + " "
return [
{"from": "human", "value": "Tell me a joke."},
{"from": "gpt", "value": title + prompt["text"]},
{"from": "human", "value": "Why is that joke funny?"},
{"from": "gpt", "value": prompt["explanation"]},
]

View File

@@ -1,17 +1,12 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""
import abc
import copy
import logging
from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.prompters import Prompter
LOG = logging.getLogger("axolotl")
@@ -21,8 +16,6 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
add_get_turns_to_conversation()
class InvalidDataException(Exception):
"""
@@ -331,154 +324,6 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
)
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for ShareGPT prompts.
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default()
conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access
)
input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}
if len(conversation.roles) == 3:
tool_role_label = conversation.roles[2]
input_roles.add(tool_role_label)
# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)
if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if not isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}")
continue
if len(part) <= 2:
role, content = part
weight = 1
else:
role, content, weight = part
# Uses "in" because role contains extra characters
input_turn = any(r.lower() in role.lower() for r in input_roles)
output_turn = any(r.lower() in role.lower() for r in output_roles)
empty_role = role.strip() == ""
if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue
if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query, we should
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
add_eos_token = not (
conversation.name == "chatml"
and conversation.sep == self.tokenizer.eos_token
)
res = self._tokenize(
turn,
add_eos_token=add_eos_token,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
labels = copy.deepcopy(res["input_ids"])
if not self.train_on_inputs:
# mask out role tokens from the labels
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
if weight == 0:
# everything from this is masked out from the labels
# (role is masked out too because it makes no sense if contents is masked out)
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
"""
Returns the default values for the tokenize prompt function

View File

@@ -5,7 +5,6 @@ from enum import Enum
from typing import Generator, Optional, Union
from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -262,166 +261,10 @@ class ReflectAlpacaPrompter(Prompter):
)
SHAREGPT_ASSERTION_FAILED_ROLE = (
ALTERNATING_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}",
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
}
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for the ShareGPT
"""
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool: Optional[str] = None
# Optional, role input/output mapping
roles: Optional[dict] = None
def __init__(
self,
prompt_style=None, # pylint: disable=unused-argument
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
if conversation:
if isinstance(conversation, Conversation):
self._conversation = conversation
else:
self._conversation = get_conv_template(conversation)
else:
self._conversation = get_conv_template("vicuna_v1.1")
if role_key_human:
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
if roles:
self.roles = roles
def _build_result(self, source):
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError(
f"A conversation entry has less than 2 messages :\n{source}"
)
conv = self._conversation.copy()
original_source = source.copy()
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.set_system_message(source[0]["value"])
source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
if self.role_key_tool:
roles[self.role_key_tool] = conv.roles[2]
try:
# Apply prompt templates
if source[0]["from"] not in roles:
# Skip the first one if it is not from human
source = source[1:]
except IndexError as err:
# sometimes there is a bing or system chat
raise err
conv.messages = []
for _, sentence in enumerate(source):
from_role = sentence["from"]
if from_role in roles:
role = roles[from_role]
else:
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
raise NotImplementedError(
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
"Please help us by creating an Issue to add support for this conversation type."
)
if self._conversation.name in ["llama3"]:
role = from_role
else:
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if (
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
turns = list(conv.get_turns())
original_source_length = len(original_source)
assert len(turns) in [
original_source_length - 1,
original_source_length,
original_source_length + 1,
]
if len(turns) == original_source_length + 1:
original_source = [{"weight": None}] + original_source
elif len(turns) == original_source_length - 1:
original_source = original_source[1:]
return [
(*turn, weight)
for turn, weight in zip(
turns,
[
1 if "weight" not in e or e["weight"] is None else e["weight"]
for e in original_source
],
)
]
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
for part in turns:
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
yield part
def __repr__(self) -> str:
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
class ShareGPTPrompterV2(ShareGPTPrompter):
"""
A V2 prompter that generates prompts for the ShareGPT
"""
def __init__(
self,
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
role_key_tool=role_key_tool,
roles=roles,
)
class UnsupportedPrompter(Prompter):
"""

View File

@@ -260,8 +260,10 @@ def train(
if not cfg.hub_model_id:
try:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -64,10 +64,7 @@ class EvalFirstStepCallback(
control: TrainerControl,
**kwargs,
):
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step == 1
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
control.should_evaluate = True
return control

File diff suppressed because one or more lines are too long

View File

@@ -1,8 +1,6 @@
"""Module for working with config dicts"""
import json
import logging
import os
from pathlib import Path
from typing import Optional
import torch
@@ -10,7 +8,6 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
@@ -215,11 +212,6 @@ def normalize_cfg_datasets(cfg):
if cfg.chat_template:
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
LOG.info(
f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].conversation = cfg.chat_template
if (
ds_cfg.type in ["orpo.chat_template", "chat_template"]
and not ds_cfg.chat_template
@@ -228,6 +220,7 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].chat_template = cfg.chat_template
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
@@ -251,391 +244,3 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
)
def legacy_validate_config(cfg):
"""
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
"""
if is_torch_bf16_gpu_available():
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
if (
not cfg.merge_lora
and not cfg.is_preprocess
and (cfg.bf16 is True or cfg.bfloat16 is True)
):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if (
# pylint: disable=too-many-boolean-expressions
not (cfg.bf16 or cfg.bfloat16)
and (cfg.fp16 or cfg.float16)
and not cfg.adapter
and not cfg.flash_attention
and cfg.sample_packing
):
LOG.warning(
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
)
# ValueError: Attempting to unscale FP16 gradients.
# OR
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
if cfg.max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
if cfg.sample_packing and cfg.rl:
raise ValueError("`sample_packing: true` does not work with RLHF training")
if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
if cfg.batch_size:
LOG.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if (
cfg.eval_batch_size
and cfg.micro_batch_size
and cfg.eval_batch_size != cfg.micro_batch_size
):
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
if cfg.adapter == "qlora":
if cfg.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if cfg.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if cfg.gptq:
raise ValueError("Can't merge qlora if gptq")
if cfg.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if cfg.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if cfg.gptq:
raise ValueError("Can't load qlora if gptq")
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
raise ValueError(
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
)
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if cfg.fsdp:
raise ValueError("fsdp not supported with ReLoRA")
if cfg.deepspeed:
raise ValueError("deepspeed not supported with ReLoRA")
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
if (
cfg.base_model and "mpt" in cfg.base_model.lower()
) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
if cfg.flash_optimum is True:
if cfg.adapter:
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bfloat16 is not True:
LOG.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
LOG.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
if cfg.pretraining_dataset and cfg.group_by_length:
LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if cfg.pretraining_dataset and not cfg.max_steps:
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
if cfg.push_to_hub_model_id:
raise ValueError(
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
)
if cfg.gptq and cfg.revision_of_model:
raise ValueError(
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove revision_of_model from the config."
)
# if cfg.sample_packing and cfg.sdp_attention:
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
# raise ValueError(
# "sample_packing not compatible with sdp_attention. Use flash_attention"
# )
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
LOG.warning(
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
"This may work on H100s."
)
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if cfg.save_steps % cfg.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if not ds_cfg.type:
continue
if ds_cfg.type == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = "sharegpt"
if "sharegpt_simple" in ds_cfg.type:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
if cfg.saves_per_epoch and cfg.save_steps:
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
cfg.evals_per_epoch
and cfg.evaluation_strategy
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if (
cfg.evaluation_strategy
and cfg.eval_steps
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if (
cfg.val_set_size == 0
and (cfg.eval_steps or cfg.evaluation_strategy)
and not cfg.test_datasets
):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if (
cfg.sample_packing
and cfg.eval_table_size
and cfg.eval_sample_packing is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
if cfg.noisy_embedding_alpha is not None:
# Deprecated, use neftune_noise_alpha
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
if cfg.neftune_noise_alpha is None:
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
else:
# User is providing both; bail and have them sort out their settings
raise ValueError(
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
)
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)
if (
cfg.unfrozen_parameters
and cfg.gradient_checkpointing_kwargs
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
):
# https://github.com/huggingface/transformers/issues/21381
raise ValueError(
"`use_reentrant` must be false when used with partially frozen model."
)
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
with open(cfg.deepspeed, encoding="utf-8") as file:
contents = file.read()
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
if cfg.flash_attention:
if (
deepspeed_cfg.zero_optimization
and deepspeed_cfg.zero_optimization.stage == 3
):
if not (
(
deepspeed_cfg.bf16
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
is True
)
or (
deepspeed_cfg.fp16
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
is True
)
):
raise ValueError(
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
)
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
LOG.warning(
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
)
if cfg.test_datasets and cfg.val_set_size:
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)
if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if cfg.eval_causal_lm_metrics:
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adaAmw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -8,9 +8,16 @@ import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from pydantic import (
BaseModel,
Field,
StringConstraints,
conlist,
field_validator,
model_validator,
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
@@ -21,6 +28,39 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -28,6 +68,7 @@ class DeprecatedParameters(BaseModel):
rope_scaling: Optional[Any] = None
noisy_embedding_alpha: Optional[float] = None
dpo_beta: Optional[float] = None
evaluation_strategy: Optional[str] = None
@field_validator("max_packed_sequence_len")
@classmethod
@@ -59,6 +100,13 @@ class DeprecatedParameters(BaseModel):
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
return dpo_beta
@field_validator("evaluation_strategy")
@classmethod
def validate_evaluation_strategy(cls, evaluation_strategy):
if evaluation_strategy is not None:
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
return evaluation_strategy
class RemappedParameters(BaseModel):
"""parameters that have been remapped to other names"""
@@ -105,13 +153,19 @@ class SFTDataset(BaseModel):
input_transform: Optional[str] = None
shards: Optional[int] = None
conversation: Optional[str] = None
chat_template: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
Union[
ChatTemplate,
str,
]
] = None
chat_template_jinja: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
input_format: Optional[str] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -122,13 +176,32 @@ class SFTDataset(BaseModel):
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
@@ -174,40 +247,13 @@ class KTODataset(BaseModel):
revision: Optional[str] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
class PeftConfig(BaseModel):
@@ -250,8 +296,8 @@ class LoraConfig(BaseModel):
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
json_schema_extra={
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
@@ -260,13 +306,15 @@ class LoraConfig(BaseModel):
loraplus_lr_ratio: Optional[float] = Field(
default=None,
metadata={
"help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
json_schema_extra={
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
},
)
loraplus_lr_embedding: Optional[float] = Field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
},
)
merge_lora: Optional[bool] = None
@@ -336,10 +384,10 @@ class ModelInputConfig(BaseModel):
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
default=None, json_schema_extra={"description": "transformers tokenizer class"}
)
processor_type: Optional[str] = Field(
default=None, metadata={"help": "transformers processor class"}
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
@@ -361,18 +409,18 @@ class HyperparametersConfig(BaseModel):
gradient_accumulation_steps: Optional[int] = Field(default=1)
micro_batch_size: Optional[int] = Field(
default=1,
metadata={"help": "per gpu micro batch size for training"},
json_schema_extra={"description": "per gpu micro batch size for training"},
)
batch_size: Optional[int] = Field(
default=None,
metadata={
"help": "Total batch size, we do not recommended setting this manually"
json_schema_extra={
"description": "Total batch size, we do not recommended setting this manually"
},
)
eval_batch_size: Optional[int] = Field(
default=None,
metadata={
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
json_schema_extra={
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
)
@@ -392,16 +440,18 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
)
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
default=None,
metadata={
"help": "The target modules to optimize, i.e. the module names that you would like to train."
json_schema_extra={
"description": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
torchdistx_path: Optional[str] = None
@@ -461,15 +511,15 @@ class LISAConfig(BaseModel):
lisa_n_layers: Optional[int] = Field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
json_schema_extra={"description": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = Field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
json_schema_extra={"description": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = Field(
default="model.layers",
metadata={"help": "path under the model to access the layers"},
json_schema_extra={"description": "path under the model to access the layers"},
)
@@ -549,9 +599,13 @@ class AxolotlInputConfig(
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
@@ -564,7 +618,8 @@ class AxolotlInputConfig(
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
] = Field(
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_keep_in_memory: Optional[bool] = None
@@ -624,7 +679,8 @@ class AxolotlInputConfig(
sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
max_prompt_len: int = Field(
default=512, metadata={"help": "maximum prompt length for RL training"}
default=512,
json_schema_extra={"description": "maximum prompt length for RL training"},
)
sample_packing: Optional[bool] = None
sample_packing_group_size: Optional[int] = 100_000
@@ -643,8 +699,8 @@ class AxolotlInputConfig(
pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field(
default=True,
metadata={
"help": "whether to prevent cross attention for packed sequences during pretraining",
json_schema_extra={
"description": "whether to prevent cross attention for packed sequences during pretraining",
},
)
@@ -690,7 +746,7 @@ class AxolotlInputConfig(
warmup_ratio: Optional[float] = None
eval_steps: Optional[Union[int, float]] = None
evals_per_epoch: Optional[Union[int]] = None
evaluation_strategy: Optional[str] = None
eval_strategy: Optional[str] = None
save_steps: Optional[Union[int, float]] = None
saves_per_epoch: Optional[int] = None
save_strategy: Optional[str] = None
@@ -718,7 +774,13 @@ class AxolotlInputConfig(
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
chat_template: Optional[ChatTemplate] = None
chat_template: Optional[
Union[
ChatTemplate,
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
]
] = None
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
@@ -736,28 +798,25 @@ class AxolotlInputConfig(
is_mistral_derived_model: Optional[bool] = Field(default=None)
is_qwen_derived_model: Optional[bool] = Field(default=None)
plugins: Optional[List[str]] = Field(default=None)
@field_validator("datasets", mode="before")
@classmethod
def fix_sharegpt_datasets(cls, datasets):
for idx, ds_cfg in enumerate(datasets):
if not ds_cfg["type"]:
def deprecate_sharegpt_datasets(cls, datasets):
for _, ds_cfg in enumerate(datasets):
if not ds_cfg.get("type"):
continue
if ds_cfg["type"] == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = "sharegpt"
if "sharegpt_simple" in ds_cfg["type"]:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = datasets[idx]["type"].replace(
"sharegpt_simple", "sharegpt"
ds_type = ds_cfg["type"]
# skip if it's a dict (for custom user instruction prompt)
if isinstance(ds_type, dict):
continue
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
raise ValueError(
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
)
return datasets
@model_validator(mode="before")
@@ -827,6 +886,23 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
@@ -972,21 +1048,21 @@ class AxolotlInputConfig(
@classmethod
def check_evals(cls, data):
if (
data.get("evaluation_strategy")
data.get("eval_strategy")
and data.get("eval_steps")
and data.get("evaluation_strategy") != "steps"
and data.get("eval_strategy") != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
)
if (
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("evaluation_strategy"))
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
"eval_steps and eval_strategy are not supported with val_set_size == 0"
)
if data.get("evals_per_epoch") and data.get("eval_steps"):
raise ValueError(
@@ -994,11 +1070,11 @@ class AxolotlInputConfig(
)
if (
data.get("evals_per_epoch")
and data.get("evaluation_strategy")
and data.get("evaluation_strategy") != "steps"
and data.get("eval_strategy")
and data.get("eval_strategy") != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if data.get("do_bench_eval") and not (
@@ -1230,6 +1306,25 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def warn_qlora_zero3_w_use_reentrant(cls, data):
if (
data.get("adapter") == "qlora"
and data.get("gradient_checkpointing_kwargs", {})
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is False
and "zero3" in data.get("deepspeed", "")
):
# may result in:
# torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
# Recomputed values for the following tensors have different metadata
# than during the forward pass.
LOG.warning(
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
)
return data
@model_validator(mode="before")
@classmethod
def check_val_w_test_datasets(cls, data):
@@ -1239,6 +1334,19 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_eval_strategy(cls, data):
if (
data.get("evaluation_strategy") is not None
and data.get("eval_strategy") is None
):
LOG.info(
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
)
data["eval_strategy"] = data.get("evaluation_strategy")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_offload_w_8bit_optimizer(cls, data):

View File

@@ -64,15 +64,57 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
tokenizer = load_tokenizer(cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
return data_set
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in ("dpo", "ipo", "orpo", "simpo"):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
raise ValueError(
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
)
prompt = sample["prompt"]
chosen = sample["chosen"]
rejected = sample["rejected"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
if rl == "kto":
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
prompt = sample["prompt"]
completion = sample["completion"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_completion = len(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
@@ -94,7 +136,7 @@ def load_prepare_dpo_datasets(cfg):
)
split_datasets.insert(i, ds)
tokenizer = None
tokenizer = load_tokenizer(cfg)
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
@@ -121,7 +163,28 @@ def load_prepare_dpo_datasets(cfg):
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
return concatenate_datasets(split_datasets)
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
return combined_datasets
with zero_first(is_main_process()):
train_is_preprocessed = False

View File

@@ -2,9 +2,11 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -53,6 +55,28 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:
@@ -236,6 +260,7 @@ def load_tokenized_prepared_datasets(
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
@@ -245,6 +270,7 @@ def load_tokenized_prepared_datasets(
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -324,7 +350,15 @@ def load_tokenized_prepared_datasets(
split=None,
)
else:
ds = load_from_disk(config_dataset.path)
try:
ds = load_from_disk(config_dataset.path)
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
@@ -342,7 +376,7 @@ def load_tokenized_prepared_datasets(
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs = {"split": config_dataset.split}
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -350,6 +384,7 @@ def load_tokenized_prepared_datasets(
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -367,6 +402,7 @@ def load_tokenized_prepared_datasets(
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
@@ -377,6 +413,7 @@ def load_tokenized_prepared_datasets(
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):

View File

@@ -0,0 +1,25 @@
"""
utils to get GPU info for the current environment
"""
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import get_gpu_info
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
unsupported_devices = {"RTX 6000 Ada"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:
if any(
unsupported_device in device_name
for device_name in device_names
for unsupported_device in unsupported_devices
):
return False
except Exception: # pylint: disable=broad-except # nosec
pass
return True

View File

@@ -14,6 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version < version.parse("2.4.0"):
torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
"""
@staticmethod
@torch.cuda.amp.custom_fwd
@torch_cuda_amp_custom_fwd
def forward(ctx, forward_function, hidden_states, *args):
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
return output
@staticmethod
@torch.cuda.amp.custom_bwd
@torch_cuda_amp_custom_bwd
def backward(ctx, dY):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()

View File

@@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault):
# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True
# Enable logging hf artifacts in mlflow if value is truthy
if cfg.hf_mlflow_log_artifacts is True:
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,508 @@
"""
Copied from https://github.com/iShohei220/adopt
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
"""
# mypy: ignore-errors
# pylint: skip-file
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast
import torch
from torch import Tensor
from torch.optim.optimizer import (
Optimizer,
ParamsT,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_disable_dynamo_if_unsupported,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_use_grad_for_differentiable,
_view_as_real,
)
__all__ = ["ADOPT", "adopt"]
class ADOPT(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
weight_decay: float = 0.0,
decoupled: bool = False,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
):
if isinstance(lr, Tensor):
if foreach and not capturable:
raise ValueError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
if lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
decoupled=decoupled,
maximize=maximize,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
fused=fused,
)
super().__init__(params, defaults)
if fused:
# TODO: support fused
raise RuntimeError("`fused` is not currently supported")
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
# TODO(crcrpar): [low prec params & their higher prec copy]
# Support AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
fused = group.setdefault("fused", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val,
dtype=_get_scalar_dtype(is_fused=fused),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is not None:
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("ADOPT does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
if group["fused"]:
_device_dtype_check_for_fused(p)
# note(crcrpar): [special device hosting for step]
# Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(is_fused=group["fused"]),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
)
# Foreach without capturable does not support a tensor lr
if (
group["foreach"]
and torch.is_tensor(group["lr"])
and not group["capturable"]
):
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = group["betas"]
has_complex = self._init_group(
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
)
adopt(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
decoupled=group["decoupled"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
capturable=group["capturable"],
differentiable=group["differentiable"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
return loss
def _single_tensor_adopt(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
has_complex: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
decoupled: bool,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
):
assert grad_scale is None and found_inf is None
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
assert isinstance(lr, float)
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step
step_t += 1
if weight_decay != 0:
if decoupled:
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
if exp_avg is not None:
exp_avg = torch.view_as_real(exp_avg)
if exp_avg_sq is not None:
exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param)
step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1:
exp_avg_sq.addcmul_(grad, grad.conj())
continue
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
if step == 2:
exp_avg.addcdiv_(grad, denom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
def _multi_tensor_adopt(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
has_complex: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
decoupled: bool,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
):
if len(params) == 0:
return
if isinstance(lr, Tensor) and not capturable:
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
assert all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
assert grad_scale is None and found_inf is None
assert not differentiable, "_foreach ops don't support autograd"
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
)
for (
device_params_,
device_grads_,
device_exp_avgs_,
device_exp_avg_sqs_,
device_state_steps_,
), _ in grouped_tensors.values():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(List[Tensor], device_state_steps_)
# Handle complex parameters
if has_complex:
_view_as_real(
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
if decoupled:
torch._foreach_add_(
device_params, device_params, alpha=-lr * weight_decay
)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
if device_state_steps[0] == 2:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
else:
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
decoupled: bool,
eps: float,
maximize: bool,
):
r"""Functional API that performs ADOPT algorithm computation."""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
if foreach and isinstance(lr, Tensor) and not capturable:
foreach = False
if fused is None:
fused = False
if foreach is None:
foreach = False
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
# if fused and not torch.jit.is_scripting():
# func = _fused_adopt
# elif foreach and not torch.jit.is_scripting():
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_adopt
else:
func = _single_tensor_adopt
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
decoupled=decoupled,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
)

View File

@@ -133,6 +133,8 @@ class MultipackBatchSampler(BatchSampler):
self.eff_total_used = 0
self.eff_total_slots = 0
self.len_across_ranks = None
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -195,15 +197,14 @@ class MultipackBatchSampler(BatchSampler):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
min_len_batches = reduce_and_broadcast(
lambda: num,
calc_min_len,
)
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
len_batches = self.num_batches()
return self.gather_len_batches(len_batches)
if not self.len_across_ranks:
len_batches = self.num_batches()
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks
def _len_est(self):
efficiency = (

View File

@@ -1,8 +1,6 @@
"""Module for tokenization utilities"""
import logging
import re
from typing import Dict, List
from termcolor import colored
@@ -68,90 +66,47 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
def check_rl_example_labels(example, tokenizer, text_only=False):
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
field_prompt, field_chosen, field_rejected, field_completion = (
"prompt",
"chosen",
"rejected",
"completion",
)
input_tokens = example[field_prompt]
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
labels_chosen = example.get(field_chosen)
labels_rejected = example.get(field_rejected)
labels_completion = example.get(field_completion)
# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "
# Process and color each type of token
colored_tokens = process_tokens_for_rl_debug(
input_tokens, "yellow", tokenizer, text_only
)
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)
# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "
# Process tokens
if labels_completion is None:
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)
else:
colored_completion = process_tokens_for_rl_debug(
labels_completion, "green", tokenizer, text_only
)
# Logging information
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
if labels_completion is None:
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
else:
LOG.info(f"COMPLETION RESPONSE: {delimiter.join(colored_completion)}\n\n\n")
return delimiter.join(colored_tokens)
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system",
"USER": "human",
"ASSISTANT": "gpt",
"FUNCTION RESPONSE": "tool",
}
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
"""
Converts a ChatML formatted row to a list of messages in ShareGPT format.
Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
"""
system_prompt = row.get("system")
if system_prompt:
system_prompt = system_prompt.removeprefix("SYSTEM: ")
chat_str = row["chat"]
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
]
if system_prompt:
chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
] + chat_msg_dicts
return chat_msg_dicts
def merge_consecutive_messages(messages):
"""
Merge consecutive messages from the same sender into a single message.
This can be useful with datasets that contain multiple consecutive tool calls.
"""
merged_messages = []
current_from = None
current_message = ""
for msg in messages:
if current_from == msg["from"]:
current_message += msg["value"]
else:
if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})
current_from = msg["from"]
current_message = msg["value"]
if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})
return merged_messages

View File

@@ -17,6 +17,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl")
@@ -184,11 +185,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
min_sequence_len=cfg.min_sample_len or 2,
)
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
@@ -203,37 +203,59 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
prior_len = len(train_dataset)
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
# drop samples with where the number of elements with labels not equal to -100 is zero
def drop_no_trainable_tokens(sample):
return np.sum(np.array(sample["labels"]) != -100) > 0
prior_len = len(train_dataset)
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
@@ -461,6 +483,9 @@ def setup_fsdp_envs(cfg):
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
os.environ["NCCL_P2P_DISABLE"] = "1"
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
@@ -490,7 +515,7 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

35
tests/e2e/conftest.py Normal file
View File

@@ -0,0 +1,35 @@
"""
shared pytest fixtures
"""
import shutil
import tempfile
import pytest
from huggingface_hub import snapshot_download
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download("HuggingFaceTB/SmolLM2-135M")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the model
snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
# download the model
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture
def temp_dir():
# Create a temporary directory
_temp_dir = tempfile.mkdtemp()
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)

View File

@@ -1,7 +1,6 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path

View File

@@ -0,0 +1,155 @@
"""
E2E tests for multigpu eval
"""
import logging
import os
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
class TestMultiGPUEval:
"""
Test case for MultiGPU Eval Sample Packing
"""
def test_eval_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -4,17 +4,17 @@ E2E tests for multigpu lora tinyllama
import logging
import os
import unittest
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import is_hopper
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -25,21 +25,19 @@ AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
# download the model
snapshot_download("TinyLlama/TinyLlama_v1.1")
snapshot_download("HuggingFaceTB/SmolLM2-135M")
class TestMultiGPULlama(unittest.TestCase):
class TestMultiGPULlama:
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"adapter": "lora",
"lora_r": 8,
@@ -48,9 +46,7 @@ class TestMultiGPULlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -59,7 +55,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -81,19 +77,23 @@ class TestMultiGPULlama(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_lora_ddp_packed(self, temp_dir):
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
)
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": True,
"eval_sample_packing": False,
@@ -105,9 +105,7 @@ class TestMultiGPULlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -116,9 +114,9 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 50,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -138,25 +136,166 @@ class TestMultiGPULlama(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp(self, temp_dir):
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
)
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -165,9 +304,9 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
@@ -184,7 +323,7 @@ class TestMultiGPULlama(unittest.TestCase):
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
@@ -201,28 +340,29 @@ class TestMultiGPULlama(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp_packed(self, temp_dir):
@pytest.mark.parametrize(
"fsdp_state_dict_type",
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
)
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -231,7 +371,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -250,7 +390,7 @@ class TestMultiGPULlama(unittest.TestCase):
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_state_dict_type": fsdp_state_dict_type,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
@@ -267,37 +407,37 @@ class TestMultiGPULlama(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@pytest.mark.skip("disabled due to upstream issue")
@with_temp_dir
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
"tokenizer_type": "AutoTokenizer",
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": [
"embed_tokens",
"lm_head",
],
# "lora_modules_to_save": [
# "embed_tokens",
# "lm_head",
# ],
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|end_of_text|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -307,7 +447,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -343,28 +483,29 @@ class TestMultiGPULlama(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_packed(self, temp_dir):
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
)
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -373,9 +514,9 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
@@ -396,19 +537,19 @@ class TestMultiGPULlama(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_qlora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
@@ -421,9 +562,7 @@ class TestMultiGPULlama(unittest.TestCase):
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -432,7 +571,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -455,6 +594,8 @@ class TestMultiGPULlama(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),

View File

@@ -4,31 +4,30 @@ E2E tests for multigpu qwen2
import logging
import os
import unittest
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
class TestMultiGPUQwen2(unittest.TestCase):
class TestMultiGPUQwen2:
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_qlora_fsdp_dpo(self, temp_dir):
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2-1.5B",
"base_model": base_model,
"load_in_4bit": True,
"rl": "dpo",
"chat_template": "chatml",
@@ -47,9 +46,9 @@ class TestMultiGPUQwen2(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 5,
"warmup_steps": 20,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -91,6 +90,8 @@ class TestMultiGPUQwen2(unittest.TestCase):
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),

Some files were not shown because too many files have changed in this diff Show More