Compare commits

..

52 Commits

Author SHA1 Message Date
Wing Lian
3afc91fba9 run 2.5.1 test without waiting for 1st e2e 2024-12-07 17:25:16 -05:00
Wing Lian
0689419d25 use pr base tag 2024-12-07 17:25:16 -05:00
Wing Lian
e64c32c0bd push test build 2024-12-07 17:25:16 -05:00
Wing Lian
ec819dde3b attempt to build the test images 2024-12-07 17:25:16 -05:00
Wing Lian
fdf4bb5087 fix default base image 2024-12-07 17:25:16 -05:00
Wing Lian
f67d16268c try with default tag 2024-12-07 17:25:16 -05:00
Wing Lian
684b543aa1 experiment with nvcr pytorch image for torch 2.5.1 2024-12-07 17:25:16 -05:00
Wing Lian
5bef19064b [tests] reset known modules that are patched on each test function end (#2147)
* reset known modules that are patched on each test function end

* fix the llama model module name

* prevent unsloth patching multiple times

* pop classes out of the globals after reset

* fix tuple indexing

* manually workaround for llama fa2
2024-12-07 17:24:46 -05:00
Wing Lian
743ba62bd5 Transformers 4.47.0 (#2138)
* bump transformers and trl

* fix: update trainer.log signature

* fix trl trainer.log interfaces

* broken 🦥 with latest transformers

* skip parent, call grandparent - yeah, super janky

* update HF HUB env var and fix reward trainer log since it doesn't directly override log

* also bump accelerate

* patches for llama ga

* detab the code to check

* fix whitespace for patch check

* play nicely with CI tests since we patch everytime

* fix pop default in case it doesn't exist

* more tweaks to make patches nicer in CI

* fix detab for when there are possibly multiple patches

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-07 05:03:01 -05:00
Chirag Jain
f9a7748bd8 Fix llama type model check (#2142) [skip ci] 2024-12-07 05:02:32 -05:00
Wing Lian
5e9fa33f3d reduce test concurrency to avoid HF rate limiting, test suite parity (#2128)
* reduce test concurrency to avoid HF rate limiting, test suite parity

* make val_set_size smaller to speed up e2e tests

* more retries for pytest fixture downloads

* val_set_size was too small

* move retry_on_request_exceptions to data utils and add retry strategy

* pre-download ultrafeedback as a test fixture

* refactor download retry into it's own fn

* don't import from data utils

* use retry mechanism now for fixtures
2024-12-06 10:20:20 -05:00
Dan Saunders
08fa133177 Fix broken CLI; remove duplicate metadata from setup.py (#2136)
* Fix broken CLI; remove duplicate metadata from setup.py

* Adding tests.yml CLI check

* updating

* remove test with requests to github due to rate limiting

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2024-12-06 10:19:54 -05:00
Wing Lian
6b3058b2dc upgrade bnb 0.45.0 and peft 0.14.0 (#2126)
* upgrade bnb to lastest release

* update peft to working supporting commit

* bump to latest release of peft==0.14.0
2024-12-06 09:08:55 -05:00
Wing Lian
5726141c4e remove accidentally included symlink (#2131) 2024-12-05 22:37:19 -05:00
Dan Saunders
2f3ebbc44f auto-versioning and adding axolotl.__version__ (#2127)
* auto-versioning and adding axolotl.__version__

* removing file meant for codecov PR

* adding dynamic dependencies, project metadata

* extras/optional-dependencies are dynamic too

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-12-05 22:12:40 -05:00
Dan Saunders
fc973f4322 CLI Implementation with Click (#2107)
* Initial CLI implementation with click package

* Adding fetch command for pulling examples and deepspeed configs

* Automating default options for CliArgs classes

* Mimicking existing no config behavior

* bugfix in choose_config

* Updating fetch to sync instead of re-download

* bugfix

* isort fix

* fixing yaml isort order

* pre-commit fixes

* simplifying argument parsing -- pass through kwargs to do_cli

* make accelerate launch default for non-preprocess commands

* fixing arg handling

* testing None placeholder approach

* removing hacky --use-gpu argument to preprocess command

* Adding brief README documentation for CLI

* remove (New)

* Initial CLI pytest tests

* progress on CLI pytest

* adding inference CLI tests; cleanup

* Refactor train CLI tests to remove various mocking

* Major CLI test refator; adding remaining CLI codepath test coverage

* pytest fixes

* remove integration markers

* parallelizing examples, deepspeed config downloads; rename test to match other CLI test naming

* moving cli pytest due to isolation issues; cleanup

* testing fixes; various minor improvements

* fix

* tests fix

* Update tests/cli/conftest.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-12-05 22:11:48 -05:00
Wing Lian
e399ba533e fix license header for fix_untrained_tokens from unsloth-zoo (#2129) [skip ci] 2024-12-05 21:20:40 -05:00
Wing Lian
4baf8e5e96 cleanup the readme, add Modal as sponsor (#2130) [skip ci] 2024-12-05 21:19:52 -05:00
Wing Lian
d7d2fd366e update from unsloth-zoo with additional fixes (#2122)
only update tokens seen in the train dataset, log them out explicitly
2024-12-04 12:26:08 -05:00
Wing Lian
e2882dd749 drop unnecessary BNB_CUDA_VERSION env var from docker as it just results in warnings (#2121) [skip ci]
* drop unnecessary BNB_CUDA_VERSION env var from docker as it just results in warnings

* make sure to run tests when cicd Dockerfile changes
2024-12-04 12:25:47 -05:00
Wing Lian
a1790f2652 replace tensorboard checks with helper function (#2120) [skip ci]
* replace tensorboard checks with helper function

* move helper function

* use relative
2024-12-03 21:06:20 -05:00
Wing Lian
418ad2b586 add missing fixture decorator for predownload dataset (#2117) [skip ci]
* add missing fixture decorator for predownload dataset

* also pre download the tokenizer files
2024-12-03 18:08:46 -05:00
Wing Lian
d87df2c776 prepare plugins needs to happen so registration can occur to build the plugin args (#2119)
* prepare plugins needs to happen so registration can occur to build the plugin args

use yaml.dump

include dataset and more assertions

* attempt to manually register plugins rather than use fn

* fix fixture

* remove fixture

* move cli test to patched dir

* fix cce validation
2024-12-03 15:06:09 -05:00
Wing Lian
1ef70312ba fix optimizer reset for relora sft (#1414)
* fix optimizer reset

* set states to reset for 8bit optimizers and handle quantile runtime error for embeddings

* fix relora test to check grad_norm

* use flash attn for relora and tweak hyperparams for test

* fix messages field for test dataset
2024-12-03 08:58:23 -05:00
NanoCode012
81ef3e45f7 fix(readme): update cuda instructions during preprocess (#2114) [skip ci] 2024-12-03 08:58:03 -05:00
NanoCode012
bd8436bc6e feat: add cut_cross_entropy (#2091)
* feat: add cut_cross_entropy

* fix: add to input

* fix: remove from setup.py

* feat: refactor into an integration

* chore: ignore lint

* feat: add test for cce

* fix: set max_steps for liger test

* chore: Update base model following suggestion

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* chore: update special_tokens following suggestion

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* chore: remove with_temp_dir following comments

* fix: plugins aren't loaded

* chore: update quotes in error message

* chore: lint

* chore: lint

* feat: enable FA on test

* chore: refactor get_pytorch_version

* fix: lock cce commit version

* fix: remove subclassing UT

* fix: downcast even if not using FA and config check

* feat: add test to check different attentions

* feat: add install to CI

* chore: refactor to use parametrize for attention

* fix: pytest not detecting test

* feat: handle torch lower than 2.4

* fix args/kwargs to match docs

* use release version cut-cross-entropy==24.11.4

* fix quotes

* fix: use named params for clarity for modal builder

* fix: handle install from pip

* fix: test check only top level module install

* fix: re-add import check

* uninstall existing version if no transformers submodule in cce

* more dataset fixtures into the cache

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-12-03 08:22:22 -05:00
Wing Lian
fc6188cd76 fix merge conflict of duplicate max_steps in config for relora (#2116) 2024-12-03 07:42:41 -05:00
Wing Lian
b9bb02406a fix so inference can be run against quantized models without adapters (#1834)
* fix so inference can be run against quantized models without adapters

* Update error msg [skip e2e]

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-03 00:02:38 -05:00
Sunny Liu
ff4794cd8e Add ds model card, rebased (#2101) [skip ci]
* rebased add_ds_model_card

* manual rebasing

* fix redundancy

* lint

* include case when ds_tag is none

* conform to kwargs in create_model_card
2024-12-03 00:02:02 -05:00
NanoCode012
822c904092 fix(vlm): handle legacy conversation data format and check image in data (#2018) [skip ci]
* fix: handle legacy conversation data format and check image in data

* feat: add test for llama vision

* feat: add max_steps to test

* fix: incorrect indent and return preprocess

* feat: use smaller model and dataset

* chore: add extra config for sharegpt dataset
2024-12-03 00:01:31 -05:00
Sunny Liu
d5f58b6509 Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)
* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accomodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_totch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-12-02 20:15:39 -05:00
Wing Lian
9f6d0b5587 use pytest sugar and verbose for more info during ci (#2112) [skip ci]
* use pytest sugar and verbose for more info during ci

* also run test suite when test requirements or cicd.sh changes

* also on PR too
2024-12-02 20:14:40 -05:00
Wing Lian
53963c792c make the eval size smaller for the resume test (#2111) [skip ci] 2024-12-02 18:32:29 -05:00
Wing Lian
a4f4a56d77 build causal_conv1d and mamba-ssm into the base image (#2113)
* build causal_conv1d and mamba-ssm into the base image

* also build base images on changes to Dockerfile-base and base workflow yaml
2024-12-02 18:27:46 -05:00
Wing Lian
ce5bcff750 various tests fixes for flakey tests (#2110)
* add mhenrichsen/alpaca_2k_test with revision dataset download fixture for flaky tests

* log slowest tests

* pin pynvml==11.5.3

* fix load local hub path

* optimize for speed w smaller models and val_set_size

* replace pynvml

* make the resume from checkpoint e2e faster

* make tests smaller
2024-12-02 17:28:58 -05:00
Oliver Molenschot
b620ed94d0 Add Exact Deduplication Feature to Preprocessing Pipeline (#2072)
* Add example YAML file for training Mistral using DPO

* added deduplication code

* Add exact deduplication feature and update examples

* Improve deduplication for train/eval overlap

Changed the deduplication function to use a more memory-efficient hashing method. Applied Git suggestions to improve clarity and maintainability.\n\nThe deduplication now handles cases where train and eval datasets have overlapping elements.

* Improve deduplication for train/eval overlap

Changed the deduplication function to use a more memory-efficient hashing method. Applied Git suggestions to improve clarity and maintainability.\n\nThe deduplication now handles cases where train and eval datasets have overlapping elements.

* Apply suggestions from code review

To handle the original case where we do not do deduplication

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Improve false collision detection to ensure dataset integrity

- Added test cases to simulate and verify handling of forced hash collisions between datasets.
- Ensured that datasets with identical hashes but different content are correctly identified, preventing incorrect deduplication.
- Updated unit tests to include scenarios where collisions occur across both training and evaluation datasets, as well as within a single dataset.

* Moved the constants file to the tests folder

- Relocated `constants.py` to the `tests` folder to improve modularity and maintain a clear separation between source and test files.
- Renamed `cicd/tests.py` to `cicd/cicd_tests.py` to resolve a conflict with `tests/__init__.py`, which caused Mypy to fail due to duplicate module names.
- Updated all references to `cicd.tests` in the codebase to `cicd.cicd_tests` to reflect the renaming and ensure compatibility.
- These changes ensure Mypy passes the pre-commit hook and maintain alignment with the project's structure.

* revert some changes from previous commit and fix relative import

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-12-02 08:47:10 -05:00
Wing Lian
5f1d98e8fc add e2e tests for Unsloth qlora and test the builds (#2093)
* see if unsloth installs cleanly in ci

* check unsloth install on regular tests, not sdist

* fix ampere check exception for ci

* use cached_property instead

* add an e2e test for unsloth qlora

* reduce seq len and mbsz to prevent oom in ci

* add checks for fp16 and sdp_attention

* pin unsloth to a specific release

* add unsloth to docker image too

* fix flash attn xentropy patch

* fix loss, add check for loss when using fa_xentropy

* fix special tokens for test

* typo

* test fa xentropy with and without gradient accum

* pr feedback changes
2024-11-29 20:38:49 -05:00
Wing Lian
1cf7075d18 support seperate lr for embeddings, similar to loraplus (#1910) [skip ci]
* support seperate lr for embeddings, similar to loraplus

* add test case for train w lr embedding scale

* use kwarg for optimizer

* make sure to handle the optimizer creation

* make sure to handle for embedding_lr too

* use smollm for e2e, check for embeddings lr first before wdecay
2024-11-29 20:38:20 -05:00
NanoCode012
f4cabc2351 fix: ds3 and fsdp lmbench eval (#2102) [ski[p ci]
* fix: ds3 and fsdp lmbench eval

* chore: update comment

* fix: test signature
2024-11-29 20:37:49 -05:00
Wing Lian
6e0fb4a6b2 add finetome dataset to fixtures, check eval_loss in test (#2106) [skip ci]
* add finetome dataset to fixtures, check eval_loss in test

* add qwen 0.5b to pytest session fixture
2024-11-29 20:37:32 -05:00
Wing Lian
724b660d56 move shared pytest conftest to top level tests (#2099) [skip ci]
* move shared pytest conftest to top level tests

* add __init__ so mypy doesn't choke on multiple conftests
2024-11-22 15:05:42 -05:00
Aman Karmani
51c9e1a035 .gitignore improvements (#349) [skip ci] 2024-11-22 11:08:54 -05:00
Sunny Liu
45c0825587 updated colab notebook (#2074)
* updated colab notebook

* update pip installtation

* cleared cell output

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* modified notebook

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* cleared cell output

* cleared unnecessary logs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-11-22 10:09:10 -05:00
Wing Lian
94fc223f6c actions/create-release is unmaintained, and doesn't create proper release notes (#2098) [skip ci] 2024-11-21 14:32:41 -05:00
Sunny Liu
151abb7a67 fix None-type not iterable error when deepspeed is left blank w/ use_… (#2087)
* fix None-type not iterable error when deepspeed is left blank w/ use_reentrant: false and qlora

* added unit test[skip e2e]

* corrected test case[skip e2e]

* assert warning message [skip e2e]

* assert warning message [skip e2e]

* corrected test cases [skip e2e]

* lint
2024-11-21 13:36:51 -05:00
Sunny Liu
bf416bdfd0 bump_liger_0.4.2 (#2096) 2024-11-21 13:24:52 -05:00
Mengqing Cao
838b74d05b Add Ascend NPU support (#1758) 2024-11-20 21:28:41 -05:00
Wing Lian
2e99bb303e fix inference when no chat_template is set, fix unsloth dora check (#2092)
* fix inference when no chat_template is set, fix unsloth dora check

* remove old unsloth version check

* update docs on installing unsloth
2024-11-20 14:07:54 -05:00
Chirag Jain
68a26f1005 Fix duplication of plugin callbacks (#2090) 2024-11-20 14:06:08 -05:00
Wing Lian
db51a9e4cb use pep440 instead of semver (#2088) [skip ci] 2024-11-19 15:02:10 -05:00
Wing Lian
8961364bc9 release 0.5.2 (#2086) 2024-11-19 12:44:42 -05:00
Wing Lian
e9c3a2aec0 add missing dunder-init for monkeypatches and add tests for install from sdist (#2085)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 121, 12.1.1, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
* add missing dunder-init for monkeypatches and add tests for install from sdist

* fix gha name

* reduce matrix for sdist test
2024-11-19 12:43:30 -05:00
109 changed files with 5271 additions and 912 deletions

View File

@@ -1,6 +1,16 @@
name: ci-cd-base
on:
push:
branches:
- "main"
paths:
- 'Dockerfile-base'
- '.github/workflows/base.yml'
pull_request:
paths:
- 'Dockerfile-base'
- '.github/workflows/base.yml'
workflow_dispatch:
jobs:
@@ -12,36 +22,38 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.10"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.10"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# - cuda: "121"
# cuda_version: 12.1.1
# cudnn_version: 8
# python_version: "3.10"
# pytorch: 2.3.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# from_base_img: ""
# from_base_tag: ""
# - cuda: "121"
# cuda_version: 12.1.1
# cudnn_version: 8
# python_version: "3.11"
# pytorch: 2.3.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# from_base_img: ""
# from_base_tag: ""
# - cuda: "124"
# cuda_version: 12.4.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: 2.4.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# from_base_img: ""
# from_base_tag: ""
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
from_base_img: nvcr.io/nvidia/pytorch
from_base_tag: 24.10-py3
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -51,7 +63,7 @@ jobs:
with:
images: |
winglian/axolotl-base
axolotlai/axolotl-base
# axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
@@ -64,7 +76,8 @@ jobs:
with:
context: .
file: ./docker/Dockerfile-base
push: ${{ github.event_name != 'pull_request' }}
push: true
# push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
@@ -74,3 +87,5 @@ jobs:
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
BASE_IMAGE=${{ matrix.from_base_img || '' }}
BASE_TAG=${{ matrix.from_base_tag || '' }}

View File

@@ -49,7 +49,7 @@ jobs:
axolotlai/axolotl
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
@@ -116,7 +116,7 @@ jobs:
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
@@ -163,7 +163,7 @@ jobs:
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch
type=semver,pattern={{version}}
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:

View File

@@ -13,19 +13,10 @@ jobs:
permissions:
contents: write
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
- name: Create release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest

View File

@@ -23,9 +23,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20
steps:
@@ -55,11 +61,18 @@ jobs:
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |

View File

@@ -8,11 +8,17 @@ on:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
pull_request:
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
@@ -39,9 +45,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20
steps:
@@ -67,73 +79,133 @@ jobs:
run: |
pip3 show torch
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest]
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 1
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
timeout-minutes: 20
steps:
- name: Checkout
- name: Check out repository code
uses: actions/checkout@v4
- name: Install Python
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
modal run cicd.tests
pip3 show torch
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
# docker-e2e-tests-1st:
# if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# # this job needs to be run on self-hosted GPU runners...
# runs-on: [self-hosted, modal]
# timeout-minutes: 90
# needs: [pre-commit, pytest, pytest-sdist]
#
# strategy:
# fail-fast: false
# matrix:
# include:
# - cuda: 124
# cuda_version: 12.4.1
# python_version: "3.11"
# pytorch: 2.4.1
# num_gpus: 1
# axolotl_extras:
# steps:
# - name: Checkout
# uses: actions/checkout@v4
# - name: Install Python
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# - name: Install Modal
# run: |
# python -m pip install --upgrade pip
# pip install modal==0.63.64 jinja2
# - name: Update env vars
# run: |
# echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
# echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
# echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
# echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
# echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
# echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
# - name: Run tests job on Modal
# run: |
# modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
# needs: [pre-commit, pytest, docker-e2e-tests-1st]
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
# - cuda: 121
# cuda_version: 12.1.1
# python_version: "3.10"
# pytorch: 2.3.1
# num_gpus: 1
# axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -153,7 +225,7 @@ jobs:
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "BASE_TAG=pr-2139-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV

3
.gitignore vendored
View File

@@ -182,3 +182,6 @@ submit.sh
typings/
out/
# vim
*.swp

253
README.md
View File

@@ -41,9 +41,12 @@ Features:
## Table of Contents
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Axolotl supports](#axolotl-supports)
- [Quickstart ⚡](#quickstart-)
- [Usage](#usage)
- [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [Axolotl supports](#axolotl-supports)
- [Advanced Setup](#advanced-setup)
- [Environment](#environment)
- [Docker](#docker)
@@ -75,14 +78,6 @@ Features:
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need help? 🙋](#need-help-)
- [Badge ❤🏷️](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
- [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
- [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
- [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)
</td>
<td>
@@ -105,6 +100,127 @@ Features:
</tr>
</table>
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip3 install packaging ninja
pip3 install -e '.[flash-attn,deepspeed]'
```
### Usage
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
### Axolotl CLI
If you've installed this package using `pip` from source, we now support a new, more
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
the above commands:
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
# finetune lora
axolotl train examples/openllama-3b/lora.yml
# inference
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out"
# gradio
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.
```bash
# Fetch example YAML files (stores in "examples/" folder)
axolotl fetch examples
# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs
# Optionally, specify a destination folder
axolotl fetch examples --dest path/to/folder
```
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Sponsors 🤝❤
If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).
---
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
---
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
Please run the quickstart instructions followed by the below to setup env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
@@ -130,41 +246,6 @@ Features:
❌: not supported
❓: untested
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip3 install packaging ninja
pip3 install -e '.[flash-attn,deepspeed]'
```
### Usage
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
## Advanced Setup
### Environment
@@ -682,86 +763,6 @@ See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, al
## Need help? 🙋
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.
Need dedicated support? Please contact us at [✉wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Community Showcase
Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b-fixed)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
PocketDoc Labs
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
Please run the quickstart instructions followed by the below to setup env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Sponsors 🤝❤
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
[NanoCode012](https://github.com/NanoCode012), [tmm1](https://github.com/tmm1),
[mhenrichsen](https://github.com/mhenrichsen), [casper-hansen](https://github.com/casper-hansen),
[hamelsmu](https://github.com/hamelsmu) and many more who help us accelerate forward by fixing bugs, answering
community questions and implementing new features. Axolotl needs donations from sponsors for the compute needed to
run our unit & integration tests, troubleshooting community issues, and providing bounties. If you love axolotl,
consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsors/OpenAccess-AI-Collective),
[Ko-fi](https://ko-fi.com/axolotl_ai) or reach out directly to
[wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org).
---
#### 💎 Diamond Sponsors - [Contact directly](mailto:wing@openaccessaicollective.org)
---
#### 🥇 Gold Sponsors - $5000/mo
---
#### 🥈 Silver Sponsors - $1000/mo
---
#### 🥉 Bronze Sponsors - $500/mo
- [JarvisLabs.ai](https://jarvislabs.ai)
---
Need dedicated support? Please contact us at [wing@axolotl.ai](ailto:wing@axolotl.ai) for dedicated support options.

View File

@@ -1,10 +1,9 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
@@ -37,6 +36,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -1,6 +1,7 @@
#!/bin/bash
set -e
pytest -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -40,6 +40,7 @@ with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
)

View File

@@ -1,11 +1,11 @@
ARG BASE_IMAGE=axolotlai/axolotl-base
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base:$BASE_TAG
FROM $BASE_IMAGE:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -26,6 +26,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install pytest

View File

@@ -3,7 +3,10 @@ ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG BASE_IMAGE=nvidia/cuda
ARG DEFAULT_TAG=${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}
ARG BASE_TAG=""
FROM ${BASE_IMAGE:-nvidia/cuda}:${BASE_TAG:-${DEFAULT_TAG}} AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
@@ -29,7 +32,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -5,7 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ARG GITHUB_REF="main"

View File

@@ -162,6 +162,9 @@ datasets:
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true
# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
@@ -406,7 +409,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1)
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd

View File

@@ -11,12 +11,10 @@ standard industry baselines.
### Installation
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
to date libraries.
The following will install the correct unsloth and extras from source.
```bash
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
python scripts/unsloth_install.py | sh
```
### Using unsloth w Axolotl

View File

@@ -2,19 +2,15 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "AKjdG7tbTb-n"
},
"metadata": {},
"source": [
"# Example notebook for running Axolotl on google colab"
"## Setting up"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RcbNpOgWRcii"
},
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
@@ -22,82 +18,76 @@
"assert (torch.cuda.is_available()==True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h3nLav8oTRA5"
},
"source": [
"## Install Axolotl and dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3c3yGAwnOIdi",
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
},
"metadata": {},
"outputs": [],
"source": [
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.7.0.post2\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
"!pip install axolotl[deepspeed]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BW2MFr7HTjub"
},
"metadata": {},
"source": [
"## Create an yaml config file"
"## Hugging Face login (optional)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pkF2dSoQEUN"
},
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"# Your YAML string\n",
"yaml_string = \"\"\"\n",
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
"model_type: LlamaForCausalLM\n",
"tokenizer_type: LlamaTokenizer\n",
"base_model: NousResearch/Meta-Llama-3.1-8B\n",
"\n",
"load_in_8bit: false\n",
"load_in_4bit: true\n",
"strict: false\n",
"\n",
"datasets:\n",
" - path: mhenrichsen/alpaca_2k_test\n",
" - path: tatsu-lab/alpaca\n",
" type: alpaca\n",
"dataset_prepared_path:\n",
"dataset_prepared_path: last_run_prepared\n",
"val_set_size: 0.05\n",
"output_dir: ./outputs/qlora-out\n",
"output_dir: ./outputs/lora-out\n",
"\n",
"sequence_len: 2048\n",
"sample_packing: true\n",
"eval_sample_packing: true\n",
"pad_to_sequence_len: true\n",
"\n",
"adapter: qlora\n",
"lora_model_dir:\n",
"\n",
"sequence_len: 4096\n",
"sample_packing: true\n",
"eval_sample_packing: false\n",
"pad_to_sequence_len: true\n",
"\n",
"lora_r: 32\n",
"lora_alpha: 16\n",
"lora_dropout: 0.05\n",
"lora_target_modules:\n",
"lora_target_linear: true\n",
"lora_fan_in_fan_out:\n",
"lora_modules_to_save:\n",
" - embed_tokens\n",
" - lm_head\n",
"\n",
"wandb_project:\n",
"wandb_entity:\n",
@@ -105,12 +95,12 @@
"wandb_name:\n",
"wandb_log_model:\n",
"\n",
"gradient_accumulation_steps: 4\n",
"micro_batch_size: 2\n",
"num_epochs: 4\n",
"optimizer: paged_adamw_32bit\n",
"gradient_accumulation_steps: 2\n",
"micro_batch_size: 1\n",
"num_epochs: 1\n",
"optimizer: paged_adamw_8bit\n",
"lr_scheduler: cosine\n",
"learning_rate: 0.0002\n",
"learning_rate: 2e-5\n",
"\n",
"train_on_inputs: false\n",
"group_by_length: false\n",
@@ -121,13 +111,15 @@
"gradient_checkpointing: true\n",
"early_stopping_patience:\n",
"resume_from_checkpoint:\n",
"local_rank:\n",
"logging_steps: 1\n",
"xformers_attention:\n",
"flash_attention: true\n",
"flash_attention: false\n",
"sdp_attention: true\n",
"\n",
"warmup_steps: 10\n",
"evals_per_epoch: 4\n",
"warmup_steps: 1\n",
"max_steps: 25\n",
"evals_per_epoch: 1\n",
"eval_table_size:\n",
"saves_per_epoch: 1\n",
"debug:\n",
"deepspeed:\n",
@@ -135,9 +127,10 @@
"fsdp:\n",
"fsdp_config:\n",
"special_tokens:\n",
"\n",
" pad_token: <|end_of_text|>\n",
"\"\"\"\n",
"\n",
"\n",
"# Convert the YAML string to a Python dictionary\n",
"yaml_dict = yaml.safe_load(yaml_string)\n",
"\n",
@@ -146,31 +139,124 @@
"\n",
"# Write the YAML file\n",
"with open(file_path, 'w') as file:\n",
" yaml.dump(yaml_dict, file)\n"
" yaml.dump(yaml_dict, file)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bidoj8YLTusD"
},
"metadata": {},
"source": [
"## Launch the training"
"Above we have a configuration file with base LLM model and datasets specified, among many other things. Axolotl can automatically detect whether the specified datasets are on HuggingFace repo or local machine.\n",
"\n",
"The Axolotl configuration options encompass model and dataset selection, data pre-processing, and training. Let's go through them line by line:\n",
"\n",
"* \"base model\": String value, specifies the underlying pre-trained LLM that will be used for finetuning\n",
"\n",
"Next we have options for model weights quantization. Quantization allows for reduction in occupied memory on GPUs.\n",
"\n",
"* \"load_in_8bit\": Boolean value, whether to quantize the model weights into 8-bit integer.\n",
"\n",
"* \"load_in_4bit\": Boolean value, whether to quantize the model weights into 4-bit integer.\n",
"\n",
"* \"strict\": Boolean value. If false, it allows for overriding established configuration options in the yaml file when executing in command-line interface.\n",
"\n",
"* \"datasets\": a list of dicts that contain path and type of data sets as well as other optional configurations where datasets are concerned. Supports multiple datasets.\n",
"\n",
"* \"val_set_size\": Either a float value less than one or an integer less than the total size of dataset. Sets the size of validation set from the whole dataset. If float, sets the proportion of the dataset assigned for validation. If integer, sets the direct size of validation set.\n",
"\n",
"* \"output_dir\": String value. Path of trained model.\n",
"\n",
"For data preprocessing:\n",
"\n",
"* \"sequence_len\": Integer. Specifies the maximum sequence length of the input. Typically 2048 or less.\n",
"\n",
"* \"pad_to_sequence_len\": Boolean. Padding input to maximum sequence length.\n",
"\n",
"* \"sample_packing\": Boolean. Specifies whether to use multi-packing with block diagonal attention.\n",
"\n",
"* \"special_tokens\": Python dict, optional. Allows users to specify the additional special tokens to be ignored by the tokenizer.\n",
"\n",
"For LoRA configuration and its hyperparamters:\n",
"\n",
"* \"adapter\": String. Either \"lora\" or \"qlora\", depending on user's choice.\n",
"\n",
"* \"lora_model_dir\": String, Optional. Path to directory that contains LoRA model, if there is already a trained LoRA model the user would like to use.\n",
"\n",
"* \"lora_r\": Integer. Refers to the rank of LoRA decomposition matrices. Higher value will reduce LoRA efficiency. Recommended to be set to 8.\n",
"\n",
"* \"lora_alpha\": Integer. Scale the weight matrices by $\\frac{\\text{lora_alpha}}{\\text{lora_r}}$Recommended to be fixed at 16.\n",
"\n",
"* \"lora_dropout\": Float that is 1 or less. The dropout probability of a lora layer.\n",
"\n",
"* \"lora_target_linear\": Boolean. If true, lora will target all linear modules in the transformers architecture.\n",
"\n",
"* \"lora_modules_to_save\": If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.\n",
"\n",
"See [LoRA](https://arxiv.org/abs/2106.09685) for detailed explanation of LoRA implementation.\n",
"\n",
"For the training configurations:\n",
"\n",
"* \"gradient_accumulation_steps\": Integer. The number of steps over which to accumulate gradient for batch training. E.g. if 2, backprop is performed every two steps.\n",
"\n",
"* \"micro_batch_size\": Integer. Batch size per gpu / gradient_accumulation_steps\n",
"\n",
"* \"num_epochs\": Integer. Number of epochs. One epoch is when training has looped over every batch in the whole data set once.\n",
"\n",
"* \"optimizer\": The optimizer to use for the training.\n",
"\n",
"* \"learning_rate\": The learning rate.\n",
"\n",
"* \"lr_scheduler\": The learning rate scheduler to use for adjusting learning rate during training.\n",
"\n",
"* \"train_on_inputs\": Boolean. Whether to ignore or include the user's prompt from the training labels.\n",
"\n",
"* \"group_by_length\": Boolean. Whether to group similarly sized data to minimize padding.\n",
"\n",
"* \"bf16\": Either \"auto\", \"true\", or \"false\". Whether to use CUDA bf16 floating point format. If set to \"auto\", will automatically apply bf16 should the gpu supports it.\n",
"\n",
"* \"fp16\": Optional. Specifies whether to use CUDA fp16. Automatically set to true if \"bf16\" is set to true. Otherwise false.\n",
"\n",
"* \"tf32\": Boolean. Whether to use CUDA tf32. Will override bf16.\n",
"\n",
"* \"gradient_checkpointing\": Boolean. Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\n",
"\n",
"* \"gradient_checkpointing_kwargs\": Python Dict. Fed into the trainer.\n",
"\n",
"* \"logging_steps\": Integer. Log training information over every specified number of steps.\n",
"\n",
"* \"flash_attention\": Boolean. Whether to use the [flash attention](https://github.com/Dao-AILab/flash-attention) mechanism.\n",
"\n",
"* \"sdp_attention\": Boolean. Whether to use the Scaled Dot Product attention mechanism (the attention mechanism in the [original implementation](https://arxiv.org/abs/1706.03762) of transformers.)\n",
"\n",
"* \"warmup_steps\": Integer. The number of pre-training steps where a very low learning rate is used.\n",
"\n",
"* \"evals_per_epoch\": Integer. Number of evaluations to be performed within one training epoch.\n",
"\n",
"* \"saves_per_epoch\": Integer. Number of times the model is saved in one training epoch.\n",
"\n",
"* \"weight_decay\": Positive Float. Sets the \"strength\" of weight decay (i.e. setting the coefficient of L2 regularization)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above is but a snippet aiming to get users familiarized with the types of streamlined configuration options axolotl provides. For a full list of configuration options, see [here](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ydTI2Jk2RStU",
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
},
"metadata": {},
"outputs": [],
"source": [
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
@@ -178,7 +264,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
"Predict with trained model"
]
},
{
@@ -187,36 +273,85 @@
"metadata": {},
"outputs": [],
"source": [
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
" --lora_model_dir=\"./outputs/lora-out\" --gradio"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deeper Dive"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also helpful to gain some familiarity over some of the core inner workings of axolotl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration Normalization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Axolotl uses a custom Dict class, called ```DictDefault```\n",
"to store configurations specified in the yaml configuration file (into a Python variable named ```cfg```). The definition for this custom Dict can be found in the [utils/dict.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/dict.py)\n",
"\n",
"```DictDefault``` is amended such that calling a missing key from it will result in a ```None``` return type. This is important because if some configuration options aren't specified by the user, the ```None``` type allows Axolotl to perform boolean operations to determine the default settings for missing configurations. For more examples on how this is done, check out [utils/config/__init__.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/__init__.py)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading Models, Tokenizers, and Trainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we inspect [cli.train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/cli/train.py), we will find that most of the heavy lifting were done by the function ```train()``` which is itself imported from [src/axolotl/train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/train.py).\n",
"\n",
"```train()``` takes care of loading the appropriate tokenizer and pre-trained model through ```load_model()``` and ```load_tokenizer()``` from [src/axolotl/utils/models.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/models.py) respectively.\n",
"\n",
"```load_tokenizer()``` loads in the appropriate tokenizer given the desired model, as well as chat templates.\n",
"\n",
"```ModelLoader``` class follows after tokenizer has been selected. It will automatically discern the base model type, load in the desired model, as well as applying model-appropriate attention mechanism modifications (e.g. flash attention). Depending on which base model the user chooses in the configuration, ```ModelLoader``` will utilize the corresponding \"attention hijacking\" script. For example, if the user specified the base model to be ```NousResearch/Meta-Llama-3.1-8B```, which is of llama type, and set ```flash_attn``` to ```True```, ```ModelLoader``` will load in [llama_attn_hijack_flash.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/llama_attn_hijack_flash.py). For a list of supported attention hijacking, please refer to the directory [/src/axolotl/monkeypatch/](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch)\n",
"\n",
"Another important operation encompassed in ```train()``` is setting up the training that takes into account of user-specified traning configurations (e.g. num_epochs, optimizer) through the use of ```setup_trainer()``` from [/src/axolotl/utils/trainer.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/trainer.py), which in turn relies on modules from [/src/axolotl/core/trainer_builder.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/core/trainer_builder.py).\n",
"```trainer_builder.py``` provides a list of trainer object options bespoke for the task type (Causal or Reinforcement learning ('dpo', 'ipo', 'kto') )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Monkey patch\n",
"\n",
"The [Monkey patch directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch) is where model architecture/optimization patching scripts are stored (these are modifications that are not implemented in the official releases, hence the name monkey patch). It includes attention jacking, ReLoRA, and unsloth optimization."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 2
}

View File

@@ -0,0 +1,95 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_exact_deduplication: true
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -0,0 +1,76 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out
dataset_exact_deduplication: true
test_value: true
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

19
pyproject.toml Normal file
View File

@@ -0,0 +1,19 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"
[project]
name = "axolotl"
dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
[project.scripts]
axolotl = "axolotl.cli.main:main"
[project.urls]
Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
[tool.setuptools_scm]

View File

@@ -2,4 +2,3 @@ pre-commit
black
mypy
types-requests
tbparse

View File

@@ -1,3 +1,5 @@
pytest
pytest-xdist
pytest-retry
pytest-sugar
tbparse

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.3
peft==0.14.0
transformers==4.47.0
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.1.0
bitsandbytes==0.45.0
accelerate==1.2.0
datasets==3.1.0
deepspeed==0.15.4
pydantic==2.6.3
@@ -26,14 +26,14 @@ numpy>=1.24.4,<=2.0.1
evaluate==0.4.1
scipy
scikit-learn==1.4.2
pynvml
nvidia-ml-py==12.560.30
art
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq==0.2.7.post2
triton>=2.3.0
liger-kernel==0.4.1
liger-kernel==0.4.2
mamba-ssm==1.2.0.post1
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl==0.12.0
trl==0.12.1
zstandard==0.22.0
fastcore

View File

@@ -0,0 +1,28 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys
try:
import torch
except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
print("")
sys.exit(0)
cce_spec = importlib.util.find_spec("cut_cross_entropy")
UNINSTALL_PREFIX = ""
if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)

View File

@@ -0,0 +1,36 @@
# noqa
# pylint: skip-file
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
)

View File

@@ -1,5 +1,4 @@
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
@@ -93,16 +92,16 @@ def parse_requirements():
install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.5.1",
description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
entry_points={
"console_scripts": [
"axolotl=axolotl.cli.main:main",
],
},
extras_require={
"flash-attn": [
"flash-attn==2.7.0.post2",

View File

@@ -0,0 +1,8 @@
"""Axolotl - Train and fine-tune large language models"""
try:
from importlib.metadata import version
__version__ = version("axolotl")
except ImportError:
__version__ = "unknown"

View File

@@ -27,14 +27,17 @@ from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
prepare_plugins,
validate_config,
)
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
@@ -97,8 +100,8 @@ def print_dep_versions():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40)
@@ -136,7 +139,7 @@ def check_remote_config(config: Union[str, Path]):
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
@@ -199,6 +202,10 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -373,7 +380,7 @@ def choose_config(path: Path):
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
return str(yaml_files[0])
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
@@ -384,7 +391,7 @@ def choose_config(path: Path):
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
@@ -419,17 +426,14 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg.axolotl_config_path = config
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
@@ -437,6 +441,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
},
)
prepare_optim_env(cfg)

View File

@@ -2,6 +2,7 @@
CLI to run inference on a trained model
"""
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -16,10 +17,10 @@ from axolotl.cli import (
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

231
src/axolotl/cli/main.py Normal file
View File

@@ -0,0 +1,231 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
import click
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
)
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group()
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.train import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
config: str,
accelerate: bool,
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["output_dir"] = base_model
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
"""Merge sharded FSDP model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.merge_sharded_fsdp_weights import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing the LoRA model to merge",
)
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save the merged model",
)
def merge_lora(
config: str,
lora_model_dir: Optional[str] = None,
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir
from axolotl.cli.merge_lora import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
"""
fetch_from_github(f"{directory}/", dest)
def main():
cli()
if __name__ == "__main__":
main()

View File

@@ -2,6 +2,7 @@
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -11,7 +12,7 @@ from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))

View File

@@ -177,7 +177,7 @@ def merge_fsdp_weights(
state.wait_for_everyone()
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))

218
src/axolotl/cli/utils.py Normal file
View File

@@ -0,0 +1,218 @@
"""Utility methods for axoltl CLI."""
import concurrent.futures
import dataclasses
import hashlib
import json
import logging
from pathlib import Path
from types import NoneType
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
LOG = logging.getLogger("axolotl.cli.utils")
def add_options_from_dataclass(config_class: Type[Any]):
"""Create Click options from the fields of a dataclass."""
def decorator(function):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name,
default=field.default,
help=field.metadata.get("description"),
)(function)
else:
option_name = f"--{field.name.replace('_', '-')}"
function = click.option(
option_name,
type=field_type,
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]):
"""Create Click options from the fields of a Pydantic model."""
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name, default=None, help=field.description
)(function)
else:
option_name = f"--{name.replace('_', '-')}"
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
"""Build command list from base command and options."""
cmd = base_cmd.copy()
for key, value in options.items():
if value is None:
continue
key = key.replace("_", "-")
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.extend([f"--{key}", str(value)])
return cmd
def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha)
raw_base_url: Base URL for raw GitHub content
dest_path: Local destination directory
dir_prefix: Directory prefix to filter files
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]
# Check if file exists and needs updating
if dest_file.exists():
with open(dest_file, "rb") as file:
content = file.read()
# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
return file_path, "unchanged"
print(f"Updating {file_path}")
status = "new"
else:
print(f"Downloading {file_path}")
status = "new"
# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Download and save file
try:
response = requests.get(raw_url, timeout=30)
response.raise_for_status()
with open(dest_file, "wb") as file:
file.write(response.content)
return file_path, status
except (requests.RequestException, IOError) as request_error:
print(f"Error downloading {file_path}: {str(request_error)}")
return file_path, "error"
def fetch_from_github(
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
dest_dir: Local destination directory
max_workers: Maximum number of concurrent downloads
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)
# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}
if not files:
raise click.ClickException(f"No files found in {dir_prefix}")
# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: Dict[str, List[str]] = {
"new": [],
"updated": [],
"unchanged": [],
"error": [],
}
# Process files in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {
executor.submit(
download_file,
(file_path, remote_sha),
raw_base_url,
dest_path,
dir_prefix,
): file_path
for file_path, remote_sha in files.items()
}
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_file):
file_path = future_to_file[future]
try:
file_path, status = future.result()
files_processed[status].append(file_path)
except (requests.RequestException, IOError) as request_error:
print(f"Error processing {file_path}: {str(request_error)}")
files_processed["error"].append(file_path)
# Log summary
LOG.info("\nSync Summary:")
LOG.info(f"New files: {len(files_processed['new'])}")
LOG.info(f"Updated files: {len(files_processed['updated'])}")
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")

View File

@@ -3,36 +3,88 @@ helper functions for fixing the embeddings/tokenizer
"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
# GNU LESSER GENERAL PUBLIC LICENSE
# Version 3, 29 June 2007
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
# Everyone is permitted to copy and distribute verbatim copies
# of this license document, but changing it is not allowed.
import gc
import itertools
import logging
from collections import Counter
import datasets
import numpy as np
import torch
LOG = logging.getLogger("axolotl.core.tokenizer_utils")
@torch.inference_mode
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
@torch.inference_mode()
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
):
"""
Many of the newer models have reserved tokens that are not trained.
Llama-3 for eg has untrained vectors in the base model.
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
We reset them to the mean of the rest of the tokens
"""
# Code licensed under LGPL
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
chat_template = getattr(tokenizer, "chat_template", None)
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
# Ignore some model checks for now
if not ignored_tokenizer_names:
ignored_tokenizer_names = []
if (
model.config._name_or_path # pylint: disable=protected-access
in ignored_tokenizer_names
):
return
# Sometimes the sizes can be different like in vision models
# Ie <image> is in input, but not in output
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
embedding_matrix = embedding_matrix[:, :min_size]
lm_head_matrix = lm_head_matrix[:, :min_size]
# Get untrained tokens
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
# Check lm_head as well
# Does NOT work for Llama 3.1!!
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
# We instead check for repeated vectors
lm_head_where = torch.where(indicator_untrained1)[0]
lm_head_bad = lm_head_matrix[lm_head_where]
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
counter = Counter()
for row in lm_head_bad:
counter[hash(row.data.tobytes())] += 1
counter = Counter({k: c for k, c in counter.items() if c >= 2})
lm_head_where = lm_head_where.cpu().numpy()
final_bad_lm_head = []
for j, row in enumerate(lm_head_bad):
if hash(row.data.tobytes()) in counter:
final_bad_lm_head.append(lm_head_where[j])
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
indicator_untrained2[final_bad_lm_head] = True
# Combine both checks
indicator_untrained = indicator_untrained1 & indicator_untrained2
# Remove pad token possibility
if hasattr(tokenizer, "pad_token_id"):
pad_token_id = tokenizer.pad_token_id
if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
indicator_untrained[pad_token_id] = False
where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
n_trained = embedding_matrix.shape[0] - n_untrained
@@ -40,10 +92,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
# Get set and actual tokens
where_untrained = where_untrained.tolist()
if len(where_untrained) == 0:
return False
return
# Remove untrained indices where it's longer
where_untrained_set = frozenset(where_untrained)
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
# Remove None items in actual_bad_tokens
@@ -53,10 +104,14 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
if_bad_first = False
if_bad_second = False
# Check tokenizer's chat template for any untrained tokens
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is not None:
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
if isinstance(train_dataset, datasets.IterableDataset):
# Skip the check, since the code below assumes
# an indexable dataset
return
# Check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
@@ -83,7 +138,69 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
# Check if bad tokens exists!
if not if_bad_first and not if_bad_second:
return False
return
# Check if lm_head / embed_token are trainable!
bad_not_trainable = False
if not embedding_matrix.requires_grad:
bad_not_trainable = True
if not lm_head_matrix.requires_grad:
bad_not_trainable = True
if bad_not_trainable: # pylint: disable=too-many-nested-blocks
final_bad_items = []
# Re-check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# Re-check last 250
left = max(size_dataset - 250, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# If no bad tokens, possibly chat template itself has issues?
if len(final_bad_items) == 0:
# Recheck 2000 and last 2000 items
size_dataset = len(train_dataset)
size = min(size_dataset, 2000)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# Re-check last 2000
left = max(size_dataset - 2000, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# Most likely false signal!
if len(final_bad_items) == 0:
return
raise ValueError(
f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
)
# Count all the possible bad tokens
final_counts = np.zeros(
@@ -97,6 +214,23 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
# Get counts for untrained tokens
counts_untrained = final_counts[where_untrained]
# Identify untrained tokens seen in train_dataset
indices_seen_in_train = np.where(counts_untrained > 0)[0]
tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
if len(tokens_to_update) == 0:
LOG.info(
"No untrained tokens found in train_dataset. No embeddings were modified."
)
return
# Log the token IDs that are being rescaled
LOG.info(
f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
)
# Get sum of all items
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
@@ -113,38 +247,26 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
mean_embedding = sum_embedding / n_trained
mean_lm_head = sum_lm_head / n_trained
# Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
# Compute scaling for tokens to update
scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
mean_embedding = (
mean_embedding.repeat(
(
n_untrained,
1,
)
)
* scaling
)
mean_lm_head = (
mean_lm_head.repeat(
(
n_untrained,
1,
)
)
* scaling
)
where_null = scaling.ravel() == 0
mean_embedding[where_null] = 0
mean_lm_head[where_null] = 0
# Set them to the mean
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
# Prepare mean embeddings for tokens to update
mean_embedding_repeated = (
mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)
mean_lm_head_repeated = (
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)
# Update embeddings only for tokens seen in train_dataset
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
embedding_matrix.dtype
)
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
# Clean up
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return True
return

View File

@@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
@dataclass
class AxolotlTrainingMixins:
"""
@@ -220,6 +236,14 @@ class AxolotlTrainingMixins:
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
@@ -386,7 +410,7 @@ class SchedulerMixin(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -410,10 +434,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -435,6 +461,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
@@ -449,30 +477,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
@@ -485,6 +542,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
@@ -516,7 +580,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
@@ -871,6 +937,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@@ -888,13 +957,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return res
def log(self, logs: Dict[str, float]) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -902,7 +973,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -994,8 +1065,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, **kwargs):
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
def create_optimizer(self):
@@ -1034,6 +1106,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@@ -1082,6 +1157,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1090,6 +1177,18 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1098,6 +1197,45 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1106,6 +1244,18 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1114,6 +1264,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
class TrainerBuilderBase(abc.ABC):
"""
@@ -1212,11 +1368,17 @@ class TrainerBuilderBase(abc.ABC):
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
@@ -1263,7 +1425,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1301,17 +1463,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
@@ -1575,6 +1727,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
@@ -1759,6 +1914,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
trainer_kwargs["tokenizer"] = self.tokenizer
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
@@ -2032,6 +2191,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,

View File

@@ -40,7 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs
**generation_kwargs,
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)

View File

@@ -0,0 +1,325 @@
Acknowledgements
Portions of this Cut Cross Entropy Software may utilize the following copyrighted
material, the use of which is hereby acknowledged.
------
PyTorch
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.
All contributions by Arm:
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
Triton
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
Transformers
Copyright 2018- The Hugging Face team. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,47 @@
Copyright (C) 2024 Apple Inc. All Rights Reserved.
IMPORTANT: This Apple software is supplied to you by Apple
Inc. ("Apple") in consideration of your agreement to the following
terms, and your use, installation, modification or redistribution of
this Apple software constitutes acceptance of these terms. If you do
not agree with these terms, please do not use, install, modify or
redistribute this Apple software.
In consideration of your agreement to abide by the following terms, and
subject to these terms, Apple grants you a personal, non-exclusive
license, under Apple's copyrights in this original Apple software (the
"Apple Software"), to use, reproduce, modify and redistribute the Apple
Software, with or without modifications, in source and/or binary forms;
provided that if you redistribute the Apple Software in its entirety and
without modifications, you must retain this notice and the following
text and disclaimers in all such redistributions of the Apple Software.
Neither the name, trademarks, service marks or logos of Apple Inc. may
be used to endorse or promote products derived from the Apple Software
without specific prior written permission from Apple. Except as
expressly stated in this notice, no other rights or licenses, express or
implied, are granted by Apple herein, including but not limited to any
patent rights that may be infringed by your derivative works or by other
works in which the Apple Software may be incorporated.
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------
SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:
The Cut Cross Entropy software includes a number of subcomponents with separate
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.
-------------------------------------------------------------------------------

View File

@@ -0,0 +1,10 @@
# Cut Cross Entropy
### Usage
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```

View File

@@ -0,0 +1,83 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for the Plugin for Cut Cross Entropy integration with Axolotl.
Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
)
class CutCrossEntropyPlugin(BasePlugin):
"""
Plugin for Cut Cross Entropy integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs"
def _check_requirements(self):
"""Check if all requirements are met."""
# Check PyTorch version
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
raise ImportError(
"Cut Cross Entropy requires PyTorch >= 2.4.0. "
f"Current version: {torch.__version__}"
)
# Check if cut_cross_entropy is installed
cce_spec = importlib.util.find_spec("cut_cross_entropy")
if cce_spec is None:
raise ImportError(_CCE_INSTALL_MESSAGE)
cce_spec_transformers = importlib.util.find_spec(
"cut_cross_entropy.transformers"
)
if cce_spec_transformers is None:
raise ImportError(_CCE_INSTALL_MESSAGE)
def pre_model_load(self, cfg):
"""Apply cut cross entropy before model loading if enabled."""
if cfg.cut_cross_entropy:
self._check_requirements()
from cut_cross_entropy.transformers import cce_patch
with zero_only():
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
# The patch checks model_type internally
cce_patch(cfg.model_config_type)

View File

@@ -0,0 +1,42 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
class CutCrossEntropyArgs(BaseModel):
"""
Input args for Cut Cross Entropy.
"""
cut_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_dtype_is_half(cls, data):
if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
raise ValueError(
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
"Please set `bf16` or `fp16` to `True`."
)
return data

View File

View File

@@ -4,7 +4,6 @@
import logging
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
@@ -94,13 +93,32 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv)
def patch_llama_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
def patch_fa_llama_cross_entropy():
LOG.info(
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
)
from flash_attn.ops.triton.cross_entropy import (
cross_entropy_loss as flash_attn_cross_entropy_loss,
)
def fa2_fixed_cross_entropy(
source,
target,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
loss, _ = flash_attn_cross_entropy_loss(
source, target, ignore_index=ignore_index
)
if reduction == "sum":
loss = loss.sum() / num_items_in_batch
else:
loss = loss.sum() / (target != ignore_index).sum()
return loss
transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
def patch_llama_rms_norm():
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled
if cross_entropy:
patch_llama_cross_entropy()
patch_fa_llama_cross_entropy()
# skip only if explicitly disabled
if rms_norm:

View File

@@ -46,9 +46,10 @@ def reset_optimizer(
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
prune_ratio: float = 0.9,
optimizer_magnitude_pruning: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
# pylint:disable=unused-argument
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
n_zeros = 0
n_total = 0
@@ -56,16 +57,22 @@ def reset_optimizer(
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state
for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
for group in optimizer.param_groups:
for param in group["params"]:
state = optimizer_state[param]
for key, value in state.items():
if key not in optimizer_state_keys:
continue
if torch.is_tensor(value):
try:
pruning_fn(value)
n_total += value.numel()
n_zeros += torch.sum(value == 0).item()
except RuntimeError as exc:
if "quantile() input tensor is too large" in str(exc):
pass
else:
raise exc
_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -129,6 +136,9 @@ class ReLoRACallback(TrainerCallback):
if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
if "8bit" in args.optim.lower():
optimizer_state_keys.append("state1")
optimizer_state_keys.append("state2")
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
@@ -160,7 +170,7 @@ class ReLoRACallback(TrainerCallback):
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
optimizer_magnitude_pruning=args.relora_prune_ratio,
)
if self.quantized:

View File

@@ -0,0 +1,207 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -9,10 +9,7 @@ import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
LOG = get_logger("axolotl.monkeypatch.unsloth")
@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
return attn_output
def get_forward_code() -> str:
forward = inspect.getsource(LlamaForCausalLM.forward)
return forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
def detab_code(code: str) -> Tuple[str, str]:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward
@@ -139,6 +141,7 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
@@ -188,7 +191,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
for module in layer_modules
)
mlp_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -213,7 +216,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
qkv_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -232,7 +235,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
o_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)

View File

@@ -259,11 +259,31 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
try:
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
# Check to make sure the base model is from HuggingFace not a local directory
hf_api = HfApi()
hf_api.model_info(cfg.base_model)
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model:
model_card_kwarg["dataset_name"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
else:
model_card_kwarg["dataset_tags"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -1,7 +1,11 @@
"""
Basic utils for Axolotl
"""
import importlib.util
import re
import torch
def is_mlflow_available():
@@ -10,3 +14,23 @@ def is_mlflow_available():
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None
# pylint: disable=duplicate-code
def get_pytorch_version() -> tuple[int, int, int]:
"""
Get Pytorch version as a tuple of (major, minor, patch).
"""
torch_version = torch.__version__
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if not version_match:
raise ValueError("Invalid version format")
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = int(patch) if patch is not None else 0 # Default patch to 0 if not present
return major, minor, patch
# pylint: enable=duplicate-code

View File

@@ -1,9 +1,23 @@
"""Benchmarking and measurement utilities"""
import functools
import pynvml
import torch
from pynvml.nvml import NVMLError
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import get_device_type
try:
from pynvml import (
NVMLError,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlInit,
)
except ImportError:
NVMLError = None
nvmlDeviceGetHandleByIndex = None
nvmlDeviceGetMemoryInfo = None
nvmlInit = None
def check_cuda_device(default_value):
@@ -53,24 +67,35 @@ def mps_memory_usage_all():
return usage, reserved - usage, 0
def npu_memory_usage_all(device=0):
usage = torch.npu.memory_allocated(device) / 1024.0**3
reserved = torch.npu.memory_reserved(device) / 1024.0**3
return usage, reserved - usage, 0
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
if not nvmlInit:
return 0.0
try:
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(device)
info = nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
except NVMLError:
return 0.0
def log_gpu_memory_usage(log, msg, device):
cur_device = get_device_type()
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in str(cur_device) and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device)
else:
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
@@ -79,6 +104,7 @@ def log_gpu_memory_usage(log, msg, device):
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
stacklevel=2,
)
return usage, cache, misc

View File

@@ -28,6 +28,7 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
@@ -46,6 +47,7 @@ from axolotl.utils.distributed import (
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
@@ -64,7 +66,10 @@ class EvalFirstStepCallback(
control: TrainerControl,
**kwargs,
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step == 1
):
control.should_evaluate = True
return control
@@ -375,7 +380,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
for metric in self.cfg.eval_causal_lm_metrics:
if metric == "perplexity":
max_seq_len = self.cfg.eval_max_new_tokens
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
metrics[metric] = Perplexity(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
)
else:
try:
metrics[metric] = evaluate.load(metric)
@@ -392,8 +400,11 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)
trainer.model_wrapped.eval()
device = torch.device(
self.cfg.device
) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
@@ -430,6 +441,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
for k in metric._feature_names() # pylint: disable=protected-access
if k in kwargs
}
if isinstance(metric, Perplexity):
metric_kwargs["model"] = trainer.model_wrapped
metric_score = metric.compute(**metric_kwargs)
return (
metric_score["score"]
@@ -465,89 +480,97 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []
for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
with unwrap_model_for_generation(
trainer.model_wrapped, trainer.accelerator
) as unwrapped_model:
for batch in tqdm(eval_dataloader, disable=not is_main_process()):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
prompt_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
pos_ranges = find_ranges(pos_ids)
batch_pos_ids = [None] * len(batch["input_ids"])
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
prompt_token_ids_list = []
completion_token_ids_list = []
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = (
input_ids != tokenizer.pad_token_id
)
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(device)
predictions = unwrapped_model.generate(
**prompt_encoding, generation_config=generation_config
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
del prompt_encoding
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list,
skip_special_tokens=True,
)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
return eval_src, eval_pred, eval_ref
if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
return control

View File

@@ -8,6 +8,8 @@ from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.utils.distributed import is_main_process
class Perplexity:
"""
@@ -17,16 +19,13 @@ class Perplexity:
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
max_seq_len: int,
stride: int = 512,
) -> None:
self.max_seq_len = max_seq_len
self.stride = stride
self.model = model
self.tokenizer = tokenizer
self.device = model.device
self.name = "perplexity"
def _feature_names(self) -> List[str]:
@@ -34,6 +33,7 @@ class Perplexity:
def compute(
self,
model: PreTrainedModel,
references: Optional[List[str]] = None,
) -> Dict[str, float]:
"""
@@ -41,17 +41,21 @@ class Perplexity:
"""
assert references is not None, "Missing parameter: references"
model.eval()
references_tokenized = self.tokenizer(
references, return_tensors="pt", padding=True, truncation=True
)
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
input_ids = input_ids.to(self.device)
input_ids = input_ids.to(model.device)
sequence_length = input_ids.size(1)
losses = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
for begin_loc in tqdm(
range(0, sequence_length, self.stride), disable=not is_main_process()
):
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
trg_len = end_loc - prev_end_loc
input_ids_slice = input_ids[:, begin_loc:end_loc]
@@ -59,7 +63,7 @@ class Perplexity:
labels_slice[:, :-trg_len] = -100
with torch.no_grad():
outputs: CausalLMOutput = self.model(
outputs: CausalLMOutput = model(
input_ids=input_ids_slice, labels=labels_slice
)

View File

@@ -1,8 +1,10 @@
"""
Collators for multi-modal chat messages and packing
"""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
@@ -30,8 +32,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
raise ValueError("Packing is currently not supported.")
def torch_call(
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
@@ -46,6 +48,120 @@ class MultiModalChatDataCollator(DataCollatorMixin):
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
def _preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
dict in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
def _process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
# Preprocess the examples
examples = _preprocess(examples)
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
@@ -53,15 +169,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
)
for example in examples
]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
images = _process_images(examples, max_images=max_images)
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

View File

@@ -5,7 +5,9 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
@@ -29,7 +31,10 @@ def choose_device(cfg):
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
if is_torch_npu_available():
return f"npu:{cfg.local_rank}"
raise SystemError("No CUDA/mps/npu device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
@@ -39,6 +44,8 @@ def choose_device(cfg):
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
elif cfg.device.startswith("npu"):
cfg.device_map = {"npu": torch.npu.current_device()}
else:
cfg.device_map = {"": cfg.device}
@@ -146,7 +153,7 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"]
and model_config.model_type in ["llama", "mllama_text_model"]
)
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
@@ -223,7 +230,11 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
@@ -233,14 +244,35 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
if capabilities:
if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."
)
return DictDefault(
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities
**cfg.to_dict(),
capabilities=capabilities,
env_capabilities=env_capabilities,
).model_dump(exclude_none=True)
)
)
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
)
def prepare_plugins(cfg):
"""
Prepare the plugins for the configuration
"""
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)

View File

@@ -7,9 +7,9 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from packaging import version
from pydantic import (
BaseModel,
Field,
@@ -20,8 +20,9 @@ from pydantic import (
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import GPUCapabilities
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -322,11 +323,13 @@ class LoraConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
if not data.get("adapter") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
if (
not data.get("adapter")
and not data.get("inference")
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
return data
@@ -430,6 +433,8 @@ class HyperparametersConfig(BaseModel):
group_by_length: Optional[bool] = None
learning_rate: Union[str, float]
embedding_lr: Optional[float] = None
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[
@@ -622,6 +627,7 @@ class AxolotlInputConfig(
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
dataloader_num_workers: Optional[int] = None
@@ -1314,6 +1320,7 @@ class AxolotlInputConfig(
and data.get("gradient_checkpointing_kwargs", {})
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is False
and data.get("deepspeed", "") is not None
and "zero3" in data.get("deepspeed", "")
):
# may result in:
@@ -1425,21 +1432,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
@@ -1449,11 +1441,46 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):
if is_torch_npu_available():
# check attention config
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list:
if data.get(attn):
raise NotImplementedError(
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
)
# check quant config
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
optimizer = data.get("optimizer")
raise NotImplementedError(
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
)
quant_list = ["load_in_8bit", "load_in_4bit"]
for quant in quant_list:
if data.get(quant):
raise NotImplementedError(
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
)
# check dtype config
if data.get("tf32"):
raise NotImplementedError(
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities
@model_validator(mode="after")
def check_bf16(self):
@@ -1494,19 +1521,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
@model_validator(mode="before")
@classmethod
def check_hopper_8bit_lora(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
# see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_deepspeed(cls, data):
@@ -1528,3 +1542,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data

View File

@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)
class EnvCapabilities(BaseModel):
"""model to manage the environment capabilities statically"""
torch_version: Optional[str] = Field(default=None)

View File

@@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.utils import md5
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
@@ -208,4 +208,9 @@ def load_prepare_dpo_datasets(cfg):
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset
)
return train_dataset, eval_dataset

View File

@@ -2,11 +2,9 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -44,7 +42,11 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
md5,
retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
@@ -55,27 +57,6 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
@@ -136,8 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets")
return train_dataset, eval_dataset, cfg.max_steps, prompters
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
@@ -584,7 +566,8 @@ def load_prepare_datasets(
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
if cfg.dataset_exact_deduplication:
_, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
dataset = dataset.train_test_split(
test_size=val_set_size,
shuffle=False,
@@ -596,12 +579,17 @@ def load_prepare_datasets(
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif split == "test":
if cfg.dataset_exact_deduplication:
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
else:
eval_dataset = dataset
train_dataset = None
eval_dataset = dataset
else:
train_dataset = dataset
if cfg.dataset_exact_deduplication:
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
else:
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset, prompters

View File

@@ -1,6 +1,55 @@
"""data handling helpers"""
import functools
import hashlib
import logging
import time
from enum import Enum
import huggingface_hub
import requests
from datasets import Dataset
LOG = logging.getLogger("axolotl")
class RetryStrategy(Enum):
"""
Enum for retry strategies.
"""
CONSTANT = 1
LINEAR = 2
EXPONENTIAL = 3
def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:
if retry_strategy == RetryStrategy.EXPONENTIAL:
step_delay = delay * 2**attempt
elif retry_strategy == RetryStrategy.LINEAR:
step_delay = delay * (attempt + 1)
else:
step_delay = delay # Use constant delay.
time.sleep(step_delay)
else:
raise exc
return wrapper
return decorator
def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -8,3 +57,96 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
def deduplicate_dataset(
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
) -> Dataset:
unique_indices = []
for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
if row_hash not in seen_hashes:
seen_hashes[row_hash] = [idx]
unique_indices.append(idx)
else:
# Check for collision by looking up the original dataset indices
original_indices = seen_hashes[row_hash]
is_duplicate = False
for original_idx in original_indices:
if (
not idx == original_idx
and original_idx < len(dataset)
and str(dataset[original_idx]) == str(row)
):
is_duplicate = True
break
# Check in the other dataset if provided
if other_dataset is not None:
if original_idx < len(other_dataset) and str(
other_dataset[original_idx]
) == str(row):
is_duplicate = True
break
if not is_duplicate:
seen_hashes[row_hash].append(idx)
unique_indices.append(idx)
continue
return dataset.select(unique_indices)
def deduplicate_and_log_datasets(
*,
train_dataset: Dataset = None,
eval_dataset: Dataset = None,
dataset: Dataset = None,
) -> tuple[Dataset, Dataset, Dataset]:
"""
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
Returns:
tuple: Deduplicated train, eval, and additional datasets.
"""
seen_hashes: dict[str, list[int]] = {}
# Handle cases where datasets are None
if train_dataset is not None:
LOG.info(
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
)
train_dataset = deduplicate_dataset(
dataset=train_dataset, seen_hashes=seen_hashes
)
LOG.info(
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
)
else:
LOG.info("Train dataset is None. Skipping deduplication.")
if eval_dataset is not None:
LOG.info(
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
)
eval_dataset = deduplicate_dataset(
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
)
LOG.info(
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
)
else:
LOG.info("Eval dataset is None. Skipping deduplication.")
if dataset is not None and (eval_dataset is None and train_dataset is None):
LOG.info(
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
)
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
LOG.info(
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
)
return train_dataset, eval_dataset, dataset

View File

@@ -9,10 +9,44 @@ from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import PartialState
from transformers.utils.import_utils import (
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
)
distributed_state = None # pylint: disable=invalid-name
def get_device_type():
device = torch.device("cpu")
if is_torch_cuda_available():
device = torch.device("cuda")
elif is_torch_mps_available():
device = torch.device("mps")
elif is_torch_npu_available():
device = torch.device("npu")
return device
def get_device_count():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.device_count()
if "npu" in str(cur_device):
return torch.npu.device_count()
return 1
def get_current_device():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.current_device()
if "npu" in str(cur_device):
return torch.npu.current_device()
return 0
def is_distributed():
"""
Check if distributed training is initialized.
@@ -91,7 +125,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
).float()
if not is_main_process():
@@ -115,13 +149,14 @@ def broadcast_dict(vals: dict):
if not is_distributed():
return vals
cur_device = get_device_type()
if is_main_process():
data_byte = pickle.dumps(vals)
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
else:
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
data_size = torch.IntTensor([0]).to("cuda")
data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
data_size = torch.IntTensor([0]).to(cur_device)
dist.broadcast(data_size, 0)
if not is_main_process():
@@ -150,14 +185,15 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
Returns:
- The computed value (int or float).
"""
cur_device = f"{get_device_type()}:{get_current_device()}"
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
value_scalar, device=cur_device, dtype=torch.float32
)
else:
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device(), dtype=torch.float32
0.0, device=cur_device, dtype=torch.float32
) # Placeholder tensor
# Broadcast the tensor to all processes.
@@ -184,7 +220,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
).float()
# Placeholder tensor for gathering results

View File

@@ -2,10 +2,12 @@
# pylint: disable=too-many-lines
import gc
import importlib
import logging
import math
import os
import types
from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
@@ -55,7 +57,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -384,6 +386,15 @@ class ModelLoader:
if self.cfg.flash_attention:
self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_forward_for_ga()
patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -409,7 +420,7 @@ class ModelLoader:
)
if self.cfg.is_llama_derived_model:
self.patch_loss()
self.patch_loss_llama()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -451,27 +462,34 @@ class ModelLoader:
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
def patch_loss(self) -> None:
@cached_property
def has_flash_attn(self) -> bool:
"""Check if flash attention is installed"""
return importlib.util.find_spec("flash_attn") is not None
def patch_loss_llama(self) -> None:
"""
Patch loss functions
"""
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)
if self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_fa_llama_cross_entropy,
patch_llama_rms_norm,
)
if self.cfg.flash_attn_cross_entropy:
patch_llama_cross_entropy()
if self.cfg.flash_attn_rms_norm:
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
patch_fa_llama_cross_entropy()
elif self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
patch_llama_rms_norm()
elif self.cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -481,6 +499,7 @@ class ModelLoader:
"""
Modify all llama derived models in one block
"""
self.patch_loss_llama()
if self.cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +547,6 @@ class ModelLoader:
"Shifted-sparse attention not currently implemented without flash attention."
)
if self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def set_auto_model_loader(self) -> None:
"""set self.AutoModelLoader
- default value: AutoModelForCausalLM (set at __init__)
@@ -570,7 +579,8 @@ class ModelLoader:
)
max_memory = {}
for i in range(torch.cuda.device_count()):
num_device = get_device_count()
for i in range(num_device):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
@@ -595,8 +605,11 @@ class ModelLoader:
self.model_kwargs["device_map"] = device_map
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
if torch.backends.mps.is_available():
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"
elif "npu" in str(cur_device):
self.model_kwargs["device_map"] = "npu:0"
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
@@ -1050,7 +1063,11 @@ class ModelLoader:
self.ajust_model_config()
# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
if hasattr(self.model, "device") and self.model.device.type in (
"cuda",
"mps",
"npu",
):
log_gpu_memory_usage(LOG, "after model load", self.model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -1076,14 +1093,17 @@ class ModelLoader:
self.prepare_model(qlora_fsdp)
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp:
LOG.info(
"converting modules to %s for flash attention", self.cfg.torch_dtype
)
should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
)
if should_convert:
LOG.info("Converting modules to %s", self.cfg.torch_dtype)
self.convert_embedding_modules_dtype(
embedding_modules,
embedding_modules=embedding_modules,
dist_dtype=self.cfg.torch_dtype,
before_kbit_train_or_finetune=False,
)
@@ -1118,9 +1138,9 @@ class ModelLoader:
and not skip_move_to_device
):
# TODO revaldate this conditional
self.model.to(f"cuda:{self.cfg.local_rank}")
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(self.model, "is_parallelizable", True)
setattr(self.model, "model_parallel", True)

View File

@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
"""
# mypy: ignore-errors
# pylint: skip-file
# flake8: noqa
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast
from typing import Callable, List, Optional, Tuple, Union, cast
import torch
from torch import Tensor
from torch.optim.optimizer import (
from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
DeviceDict,
Optimizer,
ParamsT,
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
)
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
weight_decay: float = 0.0,
decoupled: bool = False,
decouple: bool = False,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.clip_lambda = clip_lambda
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
decoupled=decoupled,
decouple=decouple,
maximize=maximize,
foreach=foreach,
capturable=capturable,
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
beta1=beta1,
beta2=beta2,
lr=group["lr"],
clip_lambda=self.clip_lambda,
weight_decay=group["weight_decay"],
decoupled=group["decoupled"],
decouple=group["decouple"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
capturable: bool,
@@ -276,14 +289,10 @@ def _single_tensor_adopt(
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step
step_t += 1
step = step_t if capturable or differentiable else _get_value(step_t)
if weight_decay != 0:
if decoupled:
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if weight_decay != 0 and not decouple:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param)
step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1:
if step == 0:
exp_avg_sq.addcmul_(grad, grad.conj())
# update step
step_t += 1
continue
if weight_decay != 0 and decouple:
param.add_(param, alpha=-lr * weight_decay)
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
if step == 2:
exp_avg.addcdiv_(grad, denom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
normed_grad = grad.div(denom)
if clip_lambda is not None:
clip = clip_lambda(step)
normed_grad.clamp_(-clip, clip)
exp_avg.lerp_(normed_grad, 1 - beta1)
param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
# update step
step_t += 1
def _multi_tensor_adopt(
params: List[Tensor],
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
capturable: bool,
@@ -376,6 +395,51 @@ def _multi_tensor_adopt(
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0 and not decouple:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 0:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
continue
if weight_decay != 0 and decouple:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if clip_lambda is not None:
clip = clip_lambda(device_state_steps[0])
torch._foreach_maximum_(normed_grad, -clip)
torch._foreach_minimum_(normed_grad, clip)
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
@@ -387,41 +451,6 @@ def _multi_tensor_adopt(
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
if decoupled:
torch._foreach_add_(
device_params, device_params, alpha=-lr * weight_decay
)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
if device_state_steps[0] == 2:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
else:
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt(
@@ -443,8 +472,9 @@ def adopt(
beta1: float,
beta2: float,
lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float,
decoupled: bool,
decouple: bool,
eps: float,
maximize: bool,
):
@@ -497,8 +527,9 @@ def adopt(
beta1=beta1,
beta2=beta2,
lr=lr,
clip_lambda=clip_lambda,
weight_decay=weight_decay,
decoupled=decoupled,
decouple=decouple,
eps=eps,
maximize=maximize,
capturable=capturable,

0
tests/cli/__init__.py Normal file
View File

36
tests/cli/conftest.py Normal file
View File

@@ -0,0 +1,36 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner
VALID_TEST_CONFIG = """
base_model: HuggingFaceTB/SmolLM2-135M
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
sequence_len: 2048
max_steps: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1e-3
special_tokens:
pad_token: <|endoftext|>
"""
@pytest.fixture
def cli_runner():
return CliRunner()
@pytest.fixture
def valid_test_config():
return VALID_TEST_CONFIG
@pytest.fixture
def config_path(tmp_path):
"""Creates a temporary config file"""
path = tmp_path / "config.yml"
path.write_text(VALID_TEST_CONFIG)
return path

View File

@@ -0,0 +1,38 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch
def test_fetch_cli_examples(cli_runner):
"""Test fetch command with examples directory"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
result = cli_runner.invoke(fetch, ["examples"])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("examples/", None)
def test_fetch_cli_deepspeed(cli_runner):
"""Test fetch command with deepspeed_configs directory"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
result = cli_runner.invoke(fetch, ["deepspeed_configs"])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("deepspeed_configs/", None)
def test_fetch_cli_with_dest(cli_runner, tmp_path):
"""Test fetch command with custom destination"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
custom_dir = tmp_path / "tmp_examples"
result = cli_runner.invoke(fetch, ["examples", "--dest", str(custom_dir)])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("examples/", str(custom_dir))
def test_fetch_cli_invalid_directory(cli_runner):
"""Test fetch command with invalid directory choice"""
result = cli_runner.invoke(fetch, ["invalid"])
assert result.exit_code != 0

View File

@@ -0,0 +1,30 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli
def test_inference_basic(cli_runner, config_path):
"""Test basic inference"""
with patch("axolotl.cli.inference.do_inference") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0
def test_inference_gradio(cli_runner, config_path):
"""Test basic inference (gradio path)"""
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate", "--gradio"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0

View File

@@ -0,0 +1,47 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli
def test_build_command():
"""Test converting dict of options to CLI arguments"""
base_cmd = ["accelerate", "launch"]
options = {
"learning_rate": 1e-4,
"batch_size": 8,
"debug": True,
"use_fp16": False,
"null_value": None,
}
result = build_command(base_cmd, options)
assert result == [
"accelerate",
"launch",
"--learning-rate",
"0.0001",
"--batch-size",
"8",
"--debug",
]
def test_invalid_command_options(cli_runner):
"""Test handling of invalid command options"""
result = cli_runner.invoke(
cli,
[
"train",
"config.yml",
"--invalid-option",
"value",
],
)
assert result.exit_code != 0
assert "No such option" in result.output
def test_required_config_argument(cli_runner):
"""Test commands fail properly when config argument is missing"""
result = cli_runner.invoke(cli, ["train"])
assert result.exit_code != 0
assert "Missing argument 'CONFIG'" in result.output

View File

@@ -0,0 +1,56 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli
def test_merge_lora_basic(cli_runner, config_path):
"""Test basic merge_lora command"""
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path):
"""Test merge_lora with custom lora and output directories"""
lora_dir = tmp_path / "lora"
output_dir = tmp_path / "output"
lora_dir.mkdir()
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli,
[
"merge-lora",
str(config_path),
"--lora-model-dir",
str(lora_dir),
"--output-dir",
str(output_dir),
],
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["lora_model_dir"] == str(lora_dir)
assert mock_do_cli.call_args.kwargs["output_dir"] == str(output_dir)
def test_merge_lora_nonexistent_config(cli_runner, tmp_path):
"""Test merge_lora with nonexistent config"""
config_path = tmp_path / "nonexistent.yml"
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
assert result.exit_code != 0
def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path):
"""Test merge_lora with nonexistent lora directory"""
lora_dir = tmp_path / "nonexistent"
result = cli_runner.invoke(
cli, ["merge-lora", str(config_path), "--lora-model-dir", str(lora_dir)]
)
assert result.exit_code != 0

View File

@@ -0,0 +1,60 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command without accelerate"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
"""Test merge_sharded_fsdp_weights command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command with save_path option"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--save-path",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -0,0 +1,71 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch
import pytest
from axolotl.cli.main import cli
@pytest.fixture(autouse=True)
def cleanup_last_run_prepared():
yield
if Path("last_run_prepared").exists():
shutil.rmtree("last_run_prepared")
def test_preprocess_config_not_found(cli_runner):
"""Test preprocess fails when config not found"""
result = cli_runner.invoke(cli, ["preprocess", "nonexistent.yml"])
assert result.exit_code != 0
def test_preprocess_basic(cli_runner, config_path):
"""Test basic preprocessing with minimal config"""
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["download"] is True
def test_preprocess_without_download(cli_runner, config_path):
"""Test preprocessing without model download"""
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli, ["preprocess", str(config_path), "--no-download"]
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["download"] is False
def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):
"""Test preprocessing with custom dataset path"""
config_path = tmp_path / "config.yml"
custom_path = tmp_path / "custom_prepared"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli,
[
"preprocess",
str(config_path),
"--dataset-prepared-path",
str(custom_path.absolute()),
],
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str(
custom_path.absolute()
)

View File

@@ -0,0 +1,76 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_shard_with_accelerate(cli_runner, config_path):
"""Test shard command with accelerate"""
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_shard_no_accelerate(cli_runner, config_path):
"""Test shard command without accelerate"""
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
assert mock.called
assert result.exit_code == 0
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
"""Test shard command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
catch_exceptions=False,
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_shard_with_save_dir(cli_runner, config_path):
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--save-dir",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -0,0 +1,98 @@
"""pytest tests for axolotl CLI train command."""
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
def test_train_cli_validation(cli_runner):
"""Test CLI validation"""
# Test missing config file
result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["train", str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.train",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
"""Test CLI arguments properly override config values"""
config_path = tmp_path / "config.yml"
output_dir = tmp_path / "model-out"
test_config = valid_test_config.replace(
"output_dir: model-out", f"output_dir: {output_dir}"
)
config_path.write_text(test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2

72
tests/cli/test_utils.py Normal file
View File

@@ -0,0 +1,72 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch
import click
import pytest
import requests
from axolotl.cli.utils import fetch_from_github
# Sample GitHub API response
MOCK_TREE_RESPONSE = {
"tree": [
{"path": "examples/config1.yml", "type": "blob", "sha": "abc123"},
{"path": "examples/config2.yml", "type": "blob", "sha": "def456"},
{"path": "other/file.txt", "type": "blob", "sha": "xyz789"},
]
}
@pytest.fixture
def mock_responses():
"""Mock responses for API and file downloads"""
def mock_get(url, timeout=None): # pylint: disable=unused-argument
response = Mock()
if "api.github.com" in url:
response.text = json.dumps(MOCK_TREE_RESPONSE)
else:
response.content = b"file content"
return response
return mock_get
def test_fetch_from_github_new_files(tmp_path, mock_responses):
"""Test fetching new files"""
with patch("requests.get", mock_responses):
fetch_from_github("examples/", tmp_path)
# Verify files were created
assert (tmp_path / "config1.yml").exists()
assert (tmp_path / "config2.yml").exists()
assert not (tmp_path / "file.txt").exists()
def test_fetch_from_github_unchanged_files(tmp_path, mock_responses):
"""Test handling of unchanged files"""
# Create existing file with matching SHA
existing_file = tmp_path / "config1.yml"
existing_file.write_bytes(b"file content")
with patch("requests.get", mock_responses):
fetch_from_github("examples/", tmp_path)
# File should not be downloaded again
assert existing_file.read_bytes() == b"file content"
def test_fetch_from_github_invalid_prefix(mock_responses):
"""Test error handling for invalid directory prefix"""
with patch("requests.get", mock_responses):
with pytest.raises(click.ClickException):
fetch_from_github("nonexistent/", None)
def test_fetch_from_github_network_error():
"""Test handling of network errors"""
with patch("requests.get", side_effect=requests.RequestException):
with pytest.raises(requests.RequestException):
fetch_from_github("examples/", None)

144
tests/conftest.py Normal file
View File

@@ -0,0 +1,144 @@
"""
shared pytest fixtures
"""
import functools
import importlib
import shutil
import sys
import tempfile
import time
import pytest
import requests
from huggingface_hub import snapshot_download
def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
# download the model
snapshot_download_w_retry("JackFram/llama-68m")
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
# download the dataset
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset
snapshot_download_w_retry(
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
)
@pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset():
# download the dataset
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
)
@pytest.fixture
def temp_dir():
# Create a temporary directory
_temp_dir = tempfile.mkdtemp()
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer",),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)

32
tests/constants.py Normal file
View File

@@ -0,0 +1,32 @@
# constants.py
"""
This module contains constants and configuration dictionaries used for
datasets and other utilities in the Axolotl project, specifically for testing.
"""
# Configuration for Alpaca Messages Dataset
ALPACA_MESSAGES_CONFIG_OG = {
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
# Revision configuration extending the original
ALPACA_MESSAGES_CONFIG_REVISION = ALPACA_MESSAGES_CONFIG_OG.copy()
ALPACA_MESSAGES_CONFIG_REVISION["revision"] = "ea82cff"
SPECIAL_TOKENS = {
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}

View File

@@ -14,9 +14,7 @@ from axolotl.utils.models import load_model, load_tokenizer
def fixture_cfg():
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00005,
@@ -33,6 +31,9 @@ def fixture_cfg():
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"model_config_type": "llama",
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)

View File

@@ -1,35 +0,0 @@
"""
shared pytest fixtures
"""
import shutil
import tempfile
import pytest
from huggingface_hub import snapshot_download
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download("HuggingFaceTB/SmolLM2-135M")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the model
snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
# download the model
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture
def temp_dir():
# Create a temporary directory
_temp_dir = tempfile.mkdtemp()
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)

View File

@@ -7,7 +7,7 @@ from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
@@ -54,8 +54,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -99,8 +101,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -0,0 +1,94 @@
"""
Simple end-to-end test for Cut Cross Entropy integration
"""
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
# pylint: disable=duplicate-code
@pytest.fixture()
def min_cfg(temp_dir):
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
class TestCutCrossEntropyIntegration:
"""
e2e tests for cut_cross_entropy integration with Axolotl
"""
# pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@pytest.mark.parametrize(
"attention_type",
["flash_attention", "sdp_attention", "xformers_attention"],
)
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
cfg = DictDefault(
min_cfg
| {
attention_type: True,
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -11,6 +11,8 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -26,7 +28,7 @@ class TestMultiGPUEval:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
@@ -40,8 +42,8 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"val_set_size": 0.004,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
@@ -66,6 +68,7 @@ class TestMultiGPUEval:
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
}
)
@@ -88,11 +91,13 @@ class TestMultiGPUEval:
]
)
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high")
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
@@ -106,8 +111,8 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"val_set_size": 0.0004,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
@@ -132,6 +137,7 @@ class TestMultiGPUEval:
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
}
)
@@ -153,3 +159,5 @@ class TestMultiGPUEval:
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.9, "Eval Loss is too high")

View File

@@ -14,8 +14,6 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import is_hopper
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -144,7 +142,6 @@ class TestMultiGPULlama:
]
)
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(

View File

@@ -42,7 +42,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.02,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -86,7 +86,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.02,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",

View File

@@ -0,0 +1,47 @@
"""
test cases to make sure the plugin args are loaded from the config file
"""
from pathlib import Path
import yaml
from axolotl.cli import load_cfg
from axolotl.utils.dict import DictDefault
# pylint: disable=duplicate-code
class TestPluginArgs:
"""
test class for plugin args loaded from the config file
"""
def test_liger_plugin_args(self, temp_dir):
test_cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
"liger_layer_norm": True,
"liger_rope": True,
"liger_rms_norm": False,
"liger_glu_activation": True,
"liger_fused_linear_cross_entropy": True,
}
)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(test_cfg.to_dict()))
cfg = load_cfg(str(Path(temp_dir) / "config.yaml"))
assert cfg.liger_layer_norm is True
assert cfg.liger_rope is True
assert cfg.liger_rms_norm is False
assert cfg.liger_glu_activation is True
assert cfg.liger_fused_linear_cross_entropy is True

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging
import os
import unittest
from importlib import reload
from pathlib import Path
@@ -17,7 +16,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -31,49 +30,55 @@ def reload_transformers():
reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama(unittest.TestCase):
class TestFAXentropyLlama:
"""
Test case for Llama models using LoRA w multipack
"""
@with_temp_dir
def test_lora_packing_fa_cross_entropy(self, temp_dir):
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
)
def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"flash_attn_cross_entropy": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.2,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"path": "mlabonne/FineTome-100k",
"field_messages": "conversations",
"message_field_content": "value",
"message_field_role": "from",
"type": "chat_template",
"split": "train[:2%]",
},
],
"num_epochs": 1,
"max_steps": 10,
"save_steps": 10,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"max_steps": 5,
"save_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
@@ -87,3 +92,7 @@ class TestFAXentropyLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
)

View File

@@ -40,7 +40,7 @@ class TestFalconPatched(unittest.TestCase):
"lora_dropout": 0.1,
"lora_target_linear": True,
"lora_modules_to_save": ["word_embeddings", "lm_head"],
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
@@ -80,7 +80,7 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",

View File

@@ -38,7 +38,7 @@ class TestFusedLlama(unittest.TestCase):
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",

View File

@@ -98,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",

View File

@@ -39,7 +39,7 @@ class TestMistral(unittest.TestCase):
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
@@ -80,7 +80,7 @@ class TestMistral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",

View File

@@ -40,7 +40,7 @@ class TestMixtral(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {},
"datasets": [
{
@@ -78,7 +78,7 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {},
"datasets": [
{

View File

@@ -38,7 +38,7 @@ class TestPhiMultipack(unittest.TestCase):
"pad_to_sequence_len": True,
"load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -6,7 +6,6 @@ import logging
import os
import re
import subprocess
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
@@ -17,35 +16,35 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import most_recent_subdir, with_temp_dir
from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestResumeLlama(unittest.TestCase):
class TestResumeLlama:
"""
Test case for resuming training of llama models
"""
@with_temp_dir
def test_resume_qlora_packed(self, temp_dir):
def test_resume_lora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {},
"val_set_size": 0.001,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
@@ -57,11 +56,11 @@ class TestResumeLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"save_steps": 10,
"save_steps": 3,
"save_total_limit": 5,
"max_steps": 40,
"max_steps": 15,
"use_tensorboard": True,
}
)
@@ -77,7 +76,7 @@ class TestResumeLlama(unittest.TestCase):
resume_cfg = cfg | DictDefault(
{
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
}
)
normalize_config(resume_cfg)
@@ -93,4 +92,4 @@ class TestResumeLlama(unittest.TestCase):
)
pattern = r"first_step\s+(\d+)"
first_steps = int(re.findall(pattern, res.stdout)[0])
assert first_steps == 31
assert first_steps == 10

View File

@@ -0,0 +1,186 @@
"""
e2e tests for unsloth qlora
"""
import logging
import os
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models
"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": sample_packing,
"flash_attention": True,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
"sdp_attention",
[True, False],
)
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"sdp_attention": sdp_attention,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -0,0 +1,113 @@
"""
E2E tests for llama pretrain
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestEmbeddingsLrScale(unittest.TestCase):
"""
Test case for embedding_lr*
"""
@with_temp_dir
def test_train_w_embedding_lr_scale(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"embedding_lr_scale": 0.5,
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)
@with_temp_dir
def test_train_w_embedding_lr(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"embedding_lr": 0.000005,
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)

View File

@@ -0,0 +1,116 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLlamaVision(unittest.TestCase):
"""
Test case for Llama Vision models
"""
@with_temp_dir
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@with_temp_dir
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "axolotl-ai-co/llava-instruct-mix-vsft-small",
"type": "chat_template",
"split": "train",
"field_messages": "messages",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()

View File

@@ -57,6 +57,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
}
)
normalize_config(cfg)

View File

@@ -56,6 +56,7 @@ class TestCustomOptimizers(unittest.TestCase):
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
}
)
@@ -94,6 +95,7 @@ class TestCustomOptimizers(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
@@ -115,7 +117,7 @@ class TestCustomOptimizers(unittest.TestCase):
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
@@ -126,13 +128,14 @@ class TestCustomOptimizers(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "schedule_free_adamw",
"lr_scheduler": "constant",
"save_safetensors": True,
"max_steps": 10,
}
)
# pylint: disable=duplicate-code

View File

@@ -6,7 +6,6 @@ import logging
import os
import unittest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
@@ -15,7 +14,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -66,9 +65,6 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -29,35 +29,48 @@ class TestReLoraLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 25,
"relora_warmup_steps": 5,
"relora_anneal_steps": 5,
"relora_steps": 100,
"relora_warmup_steps": 20,
"relora_anneal_steps": 10,
"relora_prune_ratio": 0.9,
"relora_cpu_offload": True,
"val_set_size": 0.0,
"special_tokens": {},
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"warmup_steps": 15,
"warmup_steps": 20,
"num_epochs": 2,
"micro_batch_size": 4,
"max_steps": 205, # at least 2x relora_steps
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
}
)
normalize_config(cfg)
@@ -65,4 +78,11 @@ class TestReLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
assert (
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
)

View File

@@ -12,6 +12,7 @@ import torch
# from importlib.metadata import version
from packaging import version
from tbparse import SummaryReader
def with_temp_dir(test_func):
@@ -53,7 +54,7 @@ def require_torch_2_3_1(test_case):
def require_torch_2_5_1(test_case):
"""
Decorator marking a test that requires torch >= 2.3.1
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_5_1():
@@ -66,3 +67,17 @@ def require_torch_2_5_1(test_case):
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
) -> None:
"""
helper function to parse and check tensorboard logs
"""
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
assert df.value.values[-1] < lt_val, assertion_err

Some files were not shown because too many files have changed in this diff Show More