Compare commits
340 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81d60e96f0 | ||
|
|
168a7a09cc | ||
|
|
231031a0e1 | ||
|
|
5daf7d5299 | ||
|
|
5491278a79 | ||
|
|
1514739f0f | ||
|
|
896c1aebcf | ||
|
|
ef17e15483 | ||
|
|
69a235061b | ||
|
|
687d889928 | ||
|
|
c4cf567b55 | ||
|
|
c49729d2bc | ||
|
|
13ac4d8de2 | ||
|
|
19cf0bda99 | ||
|
|
f74edd5b56 | ||
|
|
d69da99c2c | ||
|
|
66afb76a15 | ||
|
|
a692ad3f4c | ||
|
|
41da98b982 | ||
|
|
9e64f42e0f | ||
|
|
b9b7d4ce92 | ||
|
|
9bed281867 | ||
|
|
e79c8e617e | ||
|
|
71456955f5 | ||
|
|
3a783c04e4 | ||
|
|
1e5014acec | ||
|
|
a10da1caff | ||
|
|
4066c78631 | ||
|
|
78a1e1fa12 | ||
|
|
bc8a2e5547 | ||
|
|
910ebe47f5 | ||
|
|
c146880a75 | ||
|
|
77bdb7d144 | ||
|
|
530809fd74 | ||
|
|
924bbfddec | ||
|
|
f150c027e3 | ||
|
|
5c39c006c9 | ||
|
|
612aabd8c4 | ||
|
|
af05883f75 | ||
|
|
05ab9092e3 | ||
|
|
7b57ed7618 | ||
|
|
3a38271276 | ||
|
|
8d20e0a3d3 | ||
|
|
de8ed229c3 | ||
|
|
478d8c7b8e | ||
|
|
645c13592c | ||
|
|
47d601fa23 | ||
|
|
756dfba97b | ||
|
|
91ab0592af | ||
|
|
0aeb7c7802 | ||
|
|
9bdd30cdfd | ||
|
|
d35278aaf1 | ||
|
|
9492d4ebb7 | ||
|
|
ad5ca4f734 | ||
|
|
cb9d3af5c0 | ||
|
|
c969f0a9dc | ||
|
|
6d0ee4ba34 | ||
|
|
a81f52d575 | ||
|
|
1925eaf1e6 | ||
|
|
1ab3bf3e67 | ||
|
|
d7635b7148 | ||
|
|
88e17ffc50 | ||
|
|
baed440fa1 | ||
|
|
7925ddce86 | ||
|
|
6f849809c5 | ||
|
|
c16644d05e | ||
|
|
945c4191a3 | ||
|
|
136522f9c9 | ||
|
|
556fe408b3 | ||
|
|
16bb6276a5 | ||
|
|
06674a11f2 | ||
|
|
3513885f43 | ||
|
|
06652c1c39 | ||
|
|
068fc48978 | ||
|
|
aaadacf6b3 | ||
|
|
5ff547dc70 | ||
|
|
dc77c8ebce | ||
|
|
51a4c12242 | ||
|
|
4b43a66a0b | ||
|
|
34ae69989f | ||
|
|
7dc580b837 | ||
|
|
fd2c9814c9 | ||
|
|
2ba4ae8f46 | ||
|
|
93dacba228 | ||
|
|
8002ffb41f | ||
|
|
74ef5cc083 | ||
|
|
5e616d91c0 | ||
|
|
94f310c7a6 | ||
|
|
8e568bbdae | ||
|
|
e21dab49fd | ||
|
|
52cde69288 | ||
|
|
9a58e99e81 | ||
|
|
c7dee56b87 | ||
|
|
aac4b7691e | ||
|
|
f31a338cbb | ||
|
|
4cd1deeef2 | ||
|
|
9ac16ed8d1 | ||
|
|
6b3f509d9e | ||
|
|
336aa3fd48 | ||
|
|
d0d7eaa4f3 | ||
|
|
a6ebf57e82 | ||
|
|
280832cec2 | ||
|
|
a43bae9ff0 | ||
|
|
effbbf6dd1 | ||
|
|
c9a149f9e8 | ||
|
|
c530e4b9c8 | ||
|
|
f620706776 | ||
|
|
77762a5d6b | ||
|
|
14668fa54e | ||
|
|
b565ecf0a1 | ||
|
|
fe0b76854e | ||
|
|
e944311442 | ||
|
|
e3e7b52a5b | ||
|
|
974dc00a7d | ||
|
|
572d1141e6 | ||
|
|
a6190c8094 | ||
|
|
563b6d89e6 | ||
|
|
cd0a6f6027 | ||
|
|
0e664a5ebc | ||
|
|
dd7d16d2eb | ||
|
|
e285e24f7f | ||
|
|
919727b4d7 | ||
|
|
5ffefee37f | ||
|
|
d9f713e4e3 | ||
|
|
958da70376 | ||
|
|
c4e4f8115c | ||
|
|
a808bf913f | ||
|
|
01248253a3 | ||
|
|
759e8673ce | ||
|
|
0c6f928601 | ||
|
|
eea2731a5e | ||
|
|
1db46a9c72 | ||
|
|
ab5cd28acf | ||
|
|
1a82082e91 | ||
|
|
1210dc8fd5 | ||
|
|
488a67d75a | ||
|
|
71a43f8479 | ||
|
|
39619028a3 | ||
|
|
8792199799 | ||
|
|
1edc30c786 | ||
|
|
14163c15d9 | ||
|
|
41e4f6ca31 | ||
|
|
79e2a6f140 | ||
|
|
c2508987a6 | ||
|
|
215d775147 | ||
|
|
f36e227eaf | ||
|
|
5878bb1f3a | ||
|
|
a03a7d7d8b | ||
|
|
fec6bcc3e6 | ||
|
|
931e606459 | ||
|
|
7f09106437 | ||
|
|
6b50200234 | ||
|
|
16f9e28048 | ||
|
|
b9083a7fc1 | ||
|
|
aefb2fc681 | ||
|
|
b5aa8d854c | ||
|
|
4d6490bce2 | ||
|
|
b242b69e10 | ||
|
|
320beb20f4 | ||
|
|
bd3b537344 | ||
|
|
813cfa4c14 | ||
|
|
2e13ceff37 | ||
|
|
2a801b001a | ||
|
|
e44c9e0b3e | ||
|
|
55b8542de8 | ||
|
|
febe902517 | ||
|
|
f4df266842 | ||
|
|
281dc3df59 | ||
|
|
2ef4634d45 | ||
|
|
7eae90333e | ||
|
|
c8242de725 | ||
|
|
2cfe9e9b16 | ||
|
|
79a8f52181 | ||
|
|
afaa0d2c01 | ||
|
|
bfd27ba55e | ||
|
|
babf0fdb71 | ||
|
|
a52f4816b0 | ||
|
|
81911d112c | ||
|
|
52765ac588 | ||
|
|
73e9ea4069 | ||
|
|
f8d379883d | ||
|
|
04a1b77307 | ||
|
|
2097a09d2d | ||
|
|
cfff94b123 | ||
|
|
2b222de5b6 | ||
|
|
df9528f865 | ||
|
|
193c73bce0 | ||
|
|
6abfd87d44 | ||
|
|
59bb2197ed | ||
|
|
9a02e7e1ff | ||
|
|
5b33e295bd | ||
|
|
4ac9e251b7 | ||
|
|
c9c050316f | ||
|
|
ca11ae9689 | ||
|
|
328c3bce96 | ||
|
|
5cd2126439 | ||
|
|
12620f3089 | ||
|
|
4ab0c8b201 | ||
|
|
74ebbf4371 | ||
|
|
76a70fd739 | ||
|
|
618816d4df | ||
|
|
91992cb8f5 | ||
|
|
84169d15b3 | ||
|
|
ecfe8d0a1a | ||
|
|
eee44a3b47 | ||
|
|
078a43eef8 | ||
|
|
33e1890086 | ||
|
|
1c38253692 | ||
|
|
496b83f778 | ||
|
|
ff68a95781 | ||
|
|
fb3d40f197 | ||
|
|
288fd62431 | ||
|
|
3c71c8debe | ||
|
|
a6f5e5eaec | ||
|
|
5a631b305b | ||
|
|
f94dd626f0 | ||
|
|
5079753b7a | ||
|
|
0136f510f2 | ||
|
|
72bf8aafb6 | ||
|
|
8afb0fbaba | ||
|
|
9b8585dc70 | ||
|
|
8eb5811d4e | ||
|
|
e0011fdf55 | ||
|
|
6e9e98720e | ||
|
|
c2a0792680 | ||
|
|
b267d24a2b | ||
|
|
5c3f5db38b | ||
|
|
e3d03745ba | ||
|
|
fac46002d4 | ||
|
|
33d40179ba | ||
|
|
dcb03d6da4 | ||
|
|
0e4be625ae | ||
|
|
bdc4bd7d4e | ||
|
|
2d0ba3b818 | ||
|
|
c7021e191f | ||
|
|
c56818b119 | ||
|
|
2675fb756e | ||
|
|
1076bcbbca | ||
|
|
2daa6835f0 | ||
|
|
e3c494ca7b | ||
|
|
ad0ea6aaab | ||
|
|
876edd83d0 | ||
|
|
6cb2310592 | ||
|
|
6fa40bf8ad | ||
|
|
3aad5f3b3e | ||
|
|
39a208c2bc | ||
|
|
2520ecd6df | ||
|
|
c5b0af1a7e | ||
|
|
988aeb9c34 | ||
|
|
cf61f14bff | ||
|
|
0abcd71a85 | ||
|
|
c43c5c84ff | ||
|
|
36ec6e1a0e | ||
|
|
13b80937f9 | ||
|
|
bbc5bc5791 | ||
|
|
4df9da74e3 | ||
|
|
2531ea24c1 | ||
|
|
01a75fd027 | ||
|
|
b81c97ff76 | ||
|
|
594e72b6e8 | ||
|
|
25eeeeba0b | ||
|
|
cfcc549f6b | ||
|
|
a1f9850b91 | ||
|
|
83d29209f7 | ||
|
|
d011422200 | ||
|
|
b1cc54b14a | ||
|
|
c17dae6d07 | ||
|
|
37293dce07 | ||
|
|
96e8378692 | ||
|
|
e9650d3ae4 | ||
|
|
f1232b35ba | ||
|
|
741a3f2edc | ||
|
|
0dd35c74af | ||
|
|
db288e9b13 | ||
|
|
be22551435 | ||
|
|
b832a0ac62 | ||
|
|
afb31e13a3 | ||
|
|
1bf1f59a41 | ||
|
|
8e46c0fb0d | ||
|
|
1f3c3f5ea0 | ||
|
|
0e952889dc | ||
|
|
9c6750a075 | ||
|
|
c2dbf2c526 | ||
|
|
e6b57decbd | ||
|
|
fe1f4c4e7d | ||
|
|
dae14e5951 | ||
|
|
633ff2150f | ||
|
|
5d86137f70 | ||
|
|
01c8a333b3 | ||
|
|
7eb33a77dd | ||
|
|
1645a4ddd5 | ||
|
|
145b060cbe | ||
|
|
8cc0aadcb8 | ||
|
|
6abb7f6a16 | ||
|
|
de2406c488 | ||
|
|
8b617cc7f6 | ||
|
|
ddb86ea821 | ||
|
|
1a2bd7ff62 | ||
|
|
82971e1565 | ||
|
|
f4e5d86268 | ||
|
|
daf47ccf45 | ||
|
|
545cfeb5c7 | ||
|
|
69722aeef4 | ||
|
|
5658717dbd | ||
|
|
e8717d3bef | ||
|
|
54c3b5b25f | ||
|
|
5062eca069 | ||
|
|
cb4f0e9342 | ||
|
|
4c0eddb3f8 | ||
|
|
1c60c10e00 | ||
|
|
903ea3080d | ||
|
|
cb7cd3429f | ||
|
|
d57ba56746 | ||
|
|
c3a4697016 | ||
|
|
392dfd9b07 | ||
|
|
a98deb31a6 | ||
|
|
36596adaf7 | ||
|
|
6cee881d64 | ||
|
|
48612f8376 | ||
|
|
d91a769b88 | ||
|
|
6ef96f569b | ||
|
|
ac85c0ed36 | ||
|
|
f1fbf666f7 | ||
|
|
370d057096 | ||
|
|
e0ccaccce2 | ||
|
|
15e57ba6ee | ||
|
|
4eb68ac3f7 | ||
|
|
b6a539b53c | ||
|
|
abddcf4dfe | ||
|
|
15aabd2903 | ||
|
|
232b931081 | ||
|
|
0736f4f9c1 | ||
|
|
d77d736631 | ||
|
|
fad06befee | ||
|
|
2aacf75ee1 | ||
|
|
71871345a6 | ||
|
|
0d14e951a8 | ||
|
|
84fc217f79 | ||
|
|
f317296259 | ||
|
|
42a971df32 |
5
.flake8
Normal file
5
.flake8
Normal file
@@ -0,0 +1,5 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
|
||||
select = C,E,F,W,B,B950
|
||||
extend-ignore = E203, E501, W503
|
||||
31
.github/release-drafter.yml
vendored
Normal file
31
.github/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
name-template: 'v$RESOLVED_VERSION'
|
||||
tag-template: 'v$RESOLVED_VERSION'
|
||||
categories:
|
||||
- title: '🚀 Features'
|
||||
labels:
|
||||
- 'feature'
|
||||
- 'enhancement'
|
||||
- title: '🐛 Bug Fixes'
|
||||
labels:
|
||||
- 'fix'
|
||||
- 'bugfix'
|
||||
- 'bug'
|
||||
- title: '🧰 Maintenance'
|
||||
label: 'chore'
|
||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
|
||||
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.
|
||||
version-resolver:
|
||||
major:
|
||||
labels:
|
||||
- 'major'
|
||||
minor:
|
||||
labels:
|
||||
- 'minor'
|
||||
patch:
|
||||
labels:
|
||||
- 'patch'
|
||||
default: patch
|
||||
template: |
|
||||
## What’s Changed
|
||||
|
||||
$CHANGES
|
||||
28
.github/workflows/base.yml
vendored
28
.github/workflows/base.yml
vendored
@@ -12,16 +12,29 @@ jobs:
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: self-hosted
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: cu118
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
cuda_version_bnb: "118"
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
cuda_version_bnb: "117"
|
||||
axolotl_extras:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: "117"
|
||||
cuda_version: 11.7.1
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -43,12 +56,13 @@ jobs:
|
||||
context: .
|
||||
file: ./docker/Dockerfile-base
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
CUDA_VERSION_BNB=${{ matrix.cuda_version_bnb }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTHON_VERSION=${{ matrix.python_version }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}
|
||||
|
||||
41
.github/workflows/main.yml
vendored
41
.github/workflows/main.yml
vendored
@@ -11,14 +11,29 @@ jobs:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
cuda_version: 11.7.1
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -40,10 +55,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
@@ -56,10 +71,24 @@ jobs:
|
||||
include:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
cuda_version: 11.7.1
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -81,10 +110,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
file: ./docker/Dockerfile-runpod
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
16
.github/workflows/pre-commit.yml
vendored
Normal file
16
.github/workflows/pre-commit.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -7,6 +7,7 @@ jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.9", "3.10"]
|
||||
timeout-minutes: 10
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -160,4 +160,4 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
.idea/
|
||||
|
||||
2
.isort.cfg
Normal file
2
.isort.cfg
Normal file
@@ -0,0 +1,2 @@
|
||||
[settings]
|
||||
profile=black
|
||||
39
.mypy.ini
Normal file
39
.mypy.ini
Normal file
@@ -0,0 +1,39 @@
|
||||
[mypy]
|
||||
|
||||
exclude = venv
|
||||
|
||||
[mypy-alpaca_lora_4bit.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-flash_attn.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-huggingface_hub]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-transformers.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-peft]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-bitsandbytes]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-datasets]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-fire]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-setuptools]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-addict]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-xformers.*]
|
||||
ignore_missing_imports = True
|
||||
42
.pre-commit-config.yaml
Normal file
42
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
default_language_version:
|
||||
python: python3
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/pylint
|
||||
rev: v2.17.4
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.3.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
[
|
||||
'types-PyYAML',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.7.5
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
'--ini',
|
||||
'.bandit',
|
||||
]
|
||||
14
.pylintrc
Normal file
14
.pylintrc
Normal file
@@ -0,0 +1,14 @@
|
||||
[MASTER]
|
||||
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=numpy.*, torch.*
|
||||
|
||||
|
||||
[pylint.messages_control]
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
3
FAQS.md
3
FAQS.md
@@ -2,3 +2,6 @@
|
||||
|
||||
- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
|
||||
- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
|
||||
- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
|
||||
`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
|
||||
This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
|
||||
|
||||
267
README.md
267
README.md
@@ -9,36 +9,40 @@
|
||||
<p>
|
||||
Go ahead and axolotl questions!!
|
||||
</p>
|
||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Axolotl supports
|
||||
|
||||
| | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
|
||||
|---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❓ |
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
|
||||
|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
|
||||
**Requirements**: Python 3.9.
|
||||
**Requirements**: Python 3.9 and Pytorch 2.0.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
|
||||
pip3 install -e .[int4]
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
accelerate config
|
||||
|
||||
# finetune lora
|
||||
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
--inference --lora_model_dir="./lora-out"
|
||||
```
|
||||
|
||||
@@ -48,18 +52,83 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
||||
|
||||
- Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
|
||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
|
||||
- `winglian/axolotl:dev`: dev branch (not usually up to date)
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
```sh
|
||||
docker compose up -d
|
||||
```
|
||||
- `winglian/axolotl:dev`: dev branch
|
||||
- `winglian/axolotl-runpod:main`: for runpod
|
||||
|
||||
- Conda/Pip venv
|
||||
1. Install python **3.9**
|
||||
|
||||
2. Install python dependencies with ONE of the following:
|
||||
- `pip3 install -e .[int4]` (recommended)
|
||||
- `pip3 install -e .[int4_triton]`
|
||||
- `pip3 install -e .`
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install python dependencies with ONE of the following:
|
||||
- Recommended, supports QLoRA, NO gptq/int4 support
|
||||
```bash
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
- gptq/int4 support, NO QLoRA
|
||||
```bash
|
||||
pip3 install -e .[gptq]
|
||||
```
|
||||
- same as above but not recommended
|
||||
```bash
|
||||
pip3 install -e .[gptq_triton]
|
||||
```
|
||||
|
||||
- LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
|
||||
1. Install python
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install -y python3.9
|
||||
|
||||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
|
||||
sudo update-alternatives --config python # pick 3.9 if given option
|
||||
python -V # should be 3.9
|
||||
|
||||
```
|
||||
|
||||
2. Install pip
|
||||
```bash
|
||||
wget https://bootstrap.pypa.io/get-pip.py
|
||||
python get-pip.py
|
||||
```
|
||||
|
||||
3. Install torch
|
||||
```bash
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
4. Axolotl
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install -e . # change depend on needs
|
||||
pip3 install protobuf==3.20.3
|
||||
pip3 install -U requests
|
||||
pip3 install -U --ignore-installed psutil
|
||||
pip3 install -U scipy
|
||||
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
|
||||
```
|
||||
|
||||
5. Set path
|
||||
```bash
|
||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
|
||||
```
|
||||
</details>
|
||||
|
||||
### Dataset
|
||||
|
||||
@@ -69,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations
|
||||
- `sharegpt:chat`: conversations
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -110,13 +179,70 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"article": "...", "summary": "..."}
|
||||
```
|
||||
|
||||
> Have some new format to propose? Check if it's already defined in [data.py](src/axolotl/utils/data.py) in `dev` branch!
|
||||
- `alpaca_chat`: basic instruct for alpaca chat
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "response": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_qa`: question and answer for alpaca chat
|
||||
```json
|
||||
{"question": "...", "answer": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_concise`: question and answer for alpaca chat, for concise answers
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "response": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_camel_ai`: question and answer for alpaca chat, for load_camel_ai
|
||||
```json
|
||||
{"message_1": "...", "message_2": "..."}
|
||||
```
|
||||
- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
|
||||
```json
|
||||
{"system_prompt": "...", "question": "...", "response": "..."}
|
||||
```
|
||||
- `context_qa`: in context question answering from an article
|
||||
```json
|
||||
{"article": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||
```json
|
||||
{"article": "...", "unanswerable_question": "..."}
|
||||
```
|
||||
- `creative_acr.load_answer`: instruction and revision
|
||||
```json
|
||||
{"instruction": "...", "revision": "..."}
|
||||
```
|
||||
- `creative_acr.load_critique`: critique
|
||||
```json
|
||||
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "..."}
|
||||
```
|
||||
- `creative_acr.load_revise`: critique and revise
|
||||
```json
|
||||
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
|
||||
```
|
||||
- `pygmalion`: pygmalion
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||
```json
|
||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
Optionally, download some datasets, see [data/README.md](data/README.md)
|
||||
|
||||
|
||||
|
||||
### Config
|
||||
|
||||
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
|
||||
@@ -129,10 +255,18 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
|
||||
|
||||
- dataset
|
||||
```yaml
|
||||
sequence_len: 2048 # max token length for prompt
|
||||
|
||||
# huggingface repo
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4 # local or huggingface repo
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca # format from earlier
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: json
|
||||
data_files: data.jsonl # or json
|
||||
type: alpaca # format from earlier
|
||||
sequence_len: 2048 # max token length / prompt
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -142,6 +276,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
|
||||
bf16: true # require >=ampere
|
||||
fp16: true
|
||||
tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
|
||||
float16: true # use instead of fp16 when you don't want AMP
|
||||
```
|
||||
Note: Repo does not do 4-bit quantization.
|
||||
|
||||
@@ -169,12 +305,19 @@ base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# you can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Trust remote code for untrusted source
|
||||
trust_remote_code:
|
||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||
tokenizer_use_fast:
|
||||
|
||||
# whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
@@ -195,10 +338,10 @@ tf32: true # require >=ampere
|
||||
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# this can be either a hf dataset, or relative path
|
||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format OR format:prompt_style (chat/instruct)
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
data_files: # path to source data files
|
||||
shards: # number of shards to split data into
|
||||
|
||||
@@ -207,6 +350,8 @@ datasets:
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # repo path
|
||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
@@ -258,20 +403,25 @@ wandb_log_model: # 'checkpoint'
|
||||
output_dir: ./completed-model
|
||||
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
eval_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
logging_steps:
|
||||
save_steps:
|
||||
eval_steps:
|
||||
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
group_by_length: false
|
||||
|
||||
# does not work with current implementation of 4-bit LoRA
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
@@ -293,11 +443,27 @@ log_sweep_max_lr:
|
||||
optimizer:
|
||||
# specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# whether to bettertransformers
|
||||
flash_optimum:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
||||
flash_attention: # require a100 for llama
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# llama only
|
||||
xpos_rope:
|
||||
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -365,11 +531,16 @@ Pass the appropriate flag to the train command:
|
||||
|
||||
- Pretrained LORA:
|
||||
```bash
|
||||
--inference --lora_model_dir ./completed-model
|
||||
--inference --lora_model_dir="./lora-output-dir"
|
||||
```
|
||||
- Full weights finetune:
|
||||
```bash
|
||||
--inference --base_model ./completed-model
|
||||
--inference --base_model="./completed-model"
|
||||
```
|
||||
- Full weights finetune w/ a prompt from a text file:
|
||||
```bash
|
||||
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
|
||||
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
|
||||
```
|
||||
|
||||
### Merge LORA to base
|
||||
@@ -380,6 +551,12 @@ Add below flag to train command above
|
||||
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
```
|
||||
|
||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
|
||||
```
|
||||
|
||||
## Common Errors 🧰
|
||||
|
||||
> Cuda out of memory
|
||||
@@ -387,6 +564,7 @@ Add below flag to train command above
|
||||
Please reduce any below
|
||||
- `micro_batch_size`
|
||||
- `eval_batch_size`
|
||||
- `gradient_accumulation_steps`
|
||||
- `sequence_len`
|
||||
|
||||
> RuntimeError: expected scalar type Float but found Half
|
||||
@@ -397,12 +575,41 @@ Try set `fp16: true`
|
||||
|
||||
Try to turn off xformers.
|
||||
|
||||
## Need help? 🙋♂️
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
## Badge ❤🏷️
|
||||
|
||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||
|
||||
```markdown
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
```
|
||||
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
|
||||
## Community Showcase
|
||||
|
||||
Open Access AI Collective
|
||||
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
|
||||
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
|
||||
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
|
||||
|
||||
PocketDoc Labs
|
||||
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
|
||||
|
||||
## Contributing 🤝
|
||||
|
||||
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
|
||||
|
||||
PRs are **greatly welcome**!
|
||||
|
||||
Please run below to setup env
|
||||
```bash
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
pre-commit install
|
||||
|
||||
# test
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: 'NO'
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,40 +0,0 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- c_attn
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: pythia-1.4b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 32
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.0003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: True
|
||||
tf32: True
|
||||
gradient_checkpointing:
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,41 +0,0 @@
|
||||
base_model: facebook/galactica-1.3b
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 32
|
||||
micro_batch_size: 16
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
tf32: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
tokens:
|
||||
pad_token: "[PAD]"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,39 +0,0 @@
|
||||
base_model: EleutherAI/gpt-neox-20b
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: nomic-ai/gpt4all-j-prompt-generations
|
||||
type: alpaca
|
||||
shards: 4
|
||||
shards_index: 0
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- query_key_value
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project: gpt4all-neox-20b
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./gpt4all-neox-20b
|
||||
batch_size: 48
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
lr_scheduler: one_cycle
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: True
|
||||
tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,39 +0,0 @@
|
||||
base_model: huggyllama/llama-13b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
|
||||
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
|
||||
type: sharegpt
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.002
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./llama-13b-sharegpt
|
||||
batch_size: 64
|
||||
micro_batch_size: 2
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
eval_steps:
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience: 5
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,44 +0,0 @@
|
||||
base_model: huggyllama/llama-65b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
|
||||
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-65b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
micro_batch_size: 16
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca # original alpaca dataset
|
||||
type: alpaca
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-test
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
load_4bit: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
@@ -1,41 +0,0 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
micro_batch_size: 16
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca # original alpaca dataset
|
||||
type: alpaca
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-test
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
gptq: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
@@ -1,86 +0,0 @@
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
tokenizer_type: AutoTokenizer
|
||||
# whether you are training a 4-bit quantized model
|
||||
load_4bit: true
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# this can be either a hf dataset, or relative path
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
|
||||
val_set_size: 0.04
|
||||
# if you want to use lora, leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
lora_model_dir:
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
max_packed_sequence_len: 1024
|
||||
# lora hyperparameters
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
# wandb configuration if your're using it
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
# where to save the finsihed model to
|
||||
output_dir: ./completed-model
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
group_by_length: false
|
||||
# Use CUDA bf16
|
||||
bf16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true
|
||||
# does not work with current implementation of 4-bit LoRA
|
||||
gradient_checkpointing: false
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
# specify a scheduler to use with the optimizer. only one_cycle is supported currently
|
||||
lr_scheduler:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
||||
flash_attention:
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
auto_resume_from_checkpoints: false
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
@@ -1,56 +0,0 @@
|
||||
base_model: stabilityai/stablelm-base-alpha-3b
|
||||
base_model_config: stabilityai/stablelm-base-alpha-3b
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
max_packed_sequence_len: 4096
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: stable-alpaca-3b
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./stable-alpaca-3b
|
||||
batch_size: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 100
|
||||
eval_steps: 50
|
||||
save_steps: 200
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.01
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
#tokens:
|
||||
# pad_token: "[PAD]"
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
|
||||
base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_4bit: true
|
||||
gptq_groupsize: 128
|
||||
gptq_model_v1: false
|
||||
datasets:
|
||||
# https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
|
||||
- path: data/alpaca_reflect_pruned.jsonl
|
||||
type: reflection
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-reflect
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
flash_attention: true
|
||||
@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
|
||||
## Convert the JSON data files to JSONL.
|
||||
|
||||
```shell
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
```
|
||||
---
|
||||
|
||||
|
||||
20
docker-compose.yaml
Normal file
20
docker-compose.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
# version: '3.8'
|
||||
services:
|
||||
axolotl:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: ./docker/Dockerfile
|
||||
volumes:
|
||||
- .:/workspace/axolotl
|
||||
- ~/.cache/huggingface/:/root/.cache/huggingface/
|
||||
# set environment variables
|
||||
environment:
|
||||
- WANDB_API_KEY=${WANDB_API_KEY}
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
# count: 1
|
||||
capabilities: [gpu]
|
||||
command: tail -f /dev/null
|
||||
@@ -2,19 +2,25 @@ ARG BASE_TAG=main-base
|
||||
FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
RUN python3 -m pip install -U --no-cache-dir pydantic
|
||||
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main"
|
||||
|
||||
RUN mkdir axolotl
|
||||
COPY . axolotl/
|
||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN cd axolotl && \
|
||||
pip install -e .[int4]
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .; \
|
||||
fi
|
||||
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
|
||||
@@ -9,7 +9,7 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH="2.0.0"
|
||||
ARG CUDA="cu118"
|
||||
ARG CUDA="118"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
|
||||
FROM base-builder AS flash-attn-builder
|
||||
@@ -52,6 +52,8 @@ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
@@ -61,12 +63,12 @@ RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA_VERSION_BNB="118"
|
||||
ENV CUDA_VERSION_BNB=$CUDA_VERSION_BNB
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
CUDA_VERSION=$CUDA_VERSION_BNB make cuda11x && \
|
||||
CUDA_VERSION=$CUDA make cuda11x && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM base-builder
|
||||
@@ -75,7 +77,7 @@ FROM base-builder
|
||||
RUN python3 -m pip uninstall -y apex
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
|
||||
RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
|
||||
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
@@ -93,10 +95,6 @@ COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/d
|
||||
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main" && \
|
||||
pip3 install awscli && \
|
||||
RUN pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic
|
||||
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
60
examples/cerebras/qlora.yml
Normal file
60
examples/cerebras/qlora.yml
Normal file
@@ -0,0 +1,60 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- c_fc
|
||||
- c_attn
|
||||
- c_proj
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -23,7 +23,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: falcon-7b
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -61,4 +61,3 @@ special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
|
||||
92
examples/falcon/config-7b-qlora.yml
Normal file
92
examples/falcon/config-7b-qlora.yml
Normal file
@@ -0,0 +1,92 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
gptq: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: QingyiSi/Alpaca-CoT
|
||||
data_files:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
|
||||
# hyperparameters from QLoRA paper Appendix B.2
|
||||
# "We find hyperparameters to be largely robust across datasets"
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
# 0.1 for models up to 13B
|
||||
# 0.05 for 33B and 65B models
|
||||
lora_dropout: 0.05
|
||||
# add LoRA modules on all linear layers of the base model
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
# QLoRA paper Table 9
|
||||
# - 16 for 7b & 13b
|
||||
# - 32 for 33b, 64 for 64b
|
||||
# Max size tested on A6000
|
||||
# - 7b: 40
|
||||
# - 40b: 4
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
# QLoRA paper Table 9
|
||||
# - 2e-4 for 7b & 13b
|
||||
# - 1e-4 for 33b & 64b
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 5
|
||||
save_steps: 10
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.000001
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
@@ -23,7 +23,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: falcon-7b
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -61,4 +61,3 @@ special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
|
||||
57
examples/gptj/qlora.yml
Normal file
57
examples/gptj/qlora.yml
Normal file
@@ -0,0 +1,57 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
base_model_config: EleutherAI/gpt-j-6b
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -3,6 +3,6 @@
|
||||
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
|
||||
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
|
||||
|
||||
```
|
||||
|
||||
@@ -24,9 +24,9 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora-int4
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-7b-lora-int4
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
@@ -7,30 +7,28 @@ datasets:
|
||||
- path: openaccess-ai-collective/jeopardy
|
||||
type: jeopardy
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
sequence_len: 512
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: jeopardy-bot-7b
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
batch_size: 4
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
@@ -48,11 +46,10 @@ eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
tokens:
|
||||
pad_token: "[PAD]"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -22,9 +22,9 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
16
examples/openllama-3b/README.md
Normal file
16
examples/openllama-3b/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# openllama-3b
|
||||
|
||||
Basic full tune
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
|
||||
```
|
||||
|
||||
LoRA
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
QLoRA
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
|
||||
```
|
||||
62
examples/openllama-3b/config.yml
Normal file
62
examples/openllama-3b/config.yml
Normal file
@@ -0,0 +1,62 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 256
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
float16: true
|
||||
bf16: false
|
||||
fp16: false
|
||||
tf32: false
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model_config: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -49,7 +49,7 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
61
examples/openllama-3b/qlora.yml
Normal file
61
examples/openllama-3b/qlora.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
9
examples/pythia-12b/README.md
Normal file
9
examples/pythia-12b/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Pythia 12B
|
||||
|
||||
- Single-GPU A100 only (?)
|
||||
|
||||
```shell
|
||||
python scripts/finetune.py examples/pythia-12b/config.yml
|
||||
```
|
||||
|
||||
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️
|
||||
49
examples/pythia-12b/config.yml
Normal file
49
examples/pythia-12b/config.yml
Normal file
@@ -0,0 +1,49 @@
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_config: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
device_map: auto
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 64
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
fp16: false
|
||||
float16: true
|
||||
tf32: true
|
||||
flash_optimum: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
gradient_checkpointing: true
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
collator_pad_to_longest: true
|
||||
@@ -1,36 +1,29 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
sequence_len: 512
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- query_key_value
|
||||
# - xxx
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project: pythia-1.4b-lora
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 48
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
@@ -39,3 +32,6 @@ tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 20
|
||||
logging_steps: 1
|
||||
@@ -1,7 +1,7 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: GPTNeoXTokenizer
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code:
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
@@ -23,7 +23,7 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
|
||||
BIN
image/axolotl-badge-web.png
Normal file
BIN
image/axolotl-badge-web.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
3
requirements-dev.txt
Normal file
3
requirements-dev.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pre-commit
|
||||
black
|
||||
mypy
|
||||
@@ -4,16 +4,17 @@ bitsandbytes>=0.39.0
|
||||
addict
|
||||
fire
|
||||
PyYAML==6.0
|
||||
black
|
||||
datasets
|
||||
accelerate>=0.19.0
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers
|
||||
optimum
|
||||
# qlora things
|
||||
bert-score==0.3.13
|
||||
evaluate==0.4.0
|
||||
rouge-score==0.1.2
|
||||
scipy
|
||||
scikit-learn==1.2.2
|
||||
numba
|
||||
|
||||
@@ -1,24 +1,38 @@
|
||||
"""Module to convert json file to jsonl"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import fire
|
||||
from typing import Optional
|
||||
|
||||
from axolotl.convert import (
|
||||
FileReader,
|
||||
FileWriter,
|
||||
JsonlSerializer,
|
||||
JsonParser,
|
||||
JsonToJsonlConverter,
|
||||
StdoutWriter,
|
||||
)
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from axolotl.convert import *
|
||||
|
||||
|
||||
def main(
|
||||
input: Path,
|
||||
file: Path,
|
||||
output: Optional[Path] = None,
|
||||
to_stdout: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Convert a json file to jsonl
|
||||
"""
|
||||
|
||||
file_reader = FileReader()
|
||||
writer: Union[StdoutWriter, FileWriter]
|
||||
if to_stdout or output is None:
|
||||
writer = StdoutWriter()
|
||||
else:
|
||||
@@ -28,7 +42,7 @@ def main(
|
||||
|
||||
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
||||
|
||||
converter.convert(input, output)
|
||||
converter.convert(file, output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
@@ -5,25 +7,28 @@ import random
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.validation import validate_config
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.validation import validate_config
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from axolotl.utils.data import load_prepare_datasets
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
@@ -31,68 +36,101 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
def choose_device(cfg):
|
||||
def get_device():
|
||||
if torch.cuda.is_available():
|
||||
return f"cuda:{cfg.local_rank}"
|
||||
else:
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
except:
|
||||
return "cpu"
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
return f"cuda:{cfg.local_rank}"
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
|
||||
raise SystemError("No CUDA/mps device found")
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.device == "cuda":
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
if cfg.device_map != "auto":
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to finish): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||
return instruction
|
||||
|
||||
|
||||
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
||||
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
||||
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
||||
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# gc = GenerationConfig() # TODO swap out and use this
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=100,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextStreamer(tokenizer)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = [file for file in path.glob("*.yml")]
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
if not yaml_files:
|
||||
raise ValueError(
|
||||
@@ -130,31 +168,37 @@ def train(
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, "r") as f:
|
||||
cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k in kwargs:
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or cfg.strict is False:
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.gradient_accumulation_steps = (
|
||||
cfg.gradient_accumulation_steps // cfg.world_size
|
||||
)
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
@@ -163,26 +207,37 @@ def train(
|
||||
cfg.fp16 = True
|
||||
cfg.bf16 = False
|
||||
|
||||
validate_config(cfg)
|
||||
if cfg.tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# load the tokenizer first
|
||||
logging.info("loading tokenizer...")
|
||||
tokenizer = load_tokenizer(
|
||||
cfg.base_model_config,
|
||||
cfg.tokenizer_type,
|
||||
cfg
|
||||
)
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
logging.info(f"loading tokenizer... {tokenizer_config}")
|
||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||
|
||||
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
if (
|
||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||
): # don't need to load dataset for these
|
||||
if not cfg.pretraining_dataset:
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
train_dataset = load_pretraining_dataset(
|
||||
cfg.pretraining_dataset,
|
||||
tokenizer,
|
||||
max_tokens=cfg.sequence_len,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
logging.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
|
||||
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
||||
),
|
||||
tokenizer,
|
||||
)
|
||||
@@ -200,7 +255,6 @@ def train(
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter=cfg.adapter,
|
||||
inference=("inference" in kwargs),
|
||||
)
|
||||
|
||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||
@@ -213,9 +267,15 @@ def train(
|
||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
return
|
||||
|
||||
if "inference" in kwargs:
|
||||
if cfg.inference:
|
||||
logging.info("calling do_inference function")
|
||||
do_inference(cfg, model, tokenizer)
|
||||
prompter: Optional[str] = "AlpacaPrompter"
|
||||
if "prompter" in kwargs:
|
||||
if kwargs["prompter"] == "None":
|
||||
prompter = None
|
||||
else:
|
||||
prompter = kwargs["prompter"]
|
||||
do_inference(cfg, model, tokenizer, prompter=prompter)
|
||||
return
|
||||
|
||||
if "shard" in kwargs:
|
||||
@@ -237,9 +297,15 @@ def train(
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(
|
||||
signal.SIGINT,
|
||||
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
logging.info("Starting trainer...")
|
||||
@@ -252,20 +318,33 @@ def train(
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints, key=lambda path: int(path.split("-")[-1])
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
resume_from_checkpoint = sorted_paths[-1]
|
||||
logging.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
||||
)
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
export WANDB_MODE=offline
|
||||
export WANDB_CACHE_DIR=/workspace/data/wandb-cache
|
||||
mkdir -p $WANDB_CACHE_DIR
|
||||
|
||||
mkdir -p /workspace/data/huggingface-cache/{hub,datasets}
|
||||
export HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
export HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
export TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
export NCCL_P2P_DISABLE=1
|
||||
|
||||
nvidia-smi
|
||||
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
gpu_indices=$(seq 0 $((num_gpus - 1)) | paste -sd "," -)
|
||||
export CUDA_VISIBLE_DEVICES=$gpu_indices
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
|
||||
apt-get update
|
||||
apt-get install -y build-essential ninja-build vim git-lfs
|
||||
git lfs install
|
||||
pip3 install --force-reinstall https://download.pytorch.org/whl/nightly/cu117/torch-2.0.0.dev20230301%2Bcu117-cp38-cp38-linux_x86_64.whl --index-url https://download.pytorch.org/whl/nightly/cu117
|
||||
if [ -z "${TORCH_CUDA_ARCH_LIST}" ]; then # only set this if not set yet
|
||||
# this covers most common GPUs that the installed version of pytorch supports
|
||||
# python -c "import torch; print(torch.cuda.get_arch_list())"
|
||||
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
fi
|
||||
|
||||
# install flash-attn and deepspeed from pre-built wheels for this specific container b/c these take forever to install
|
||||
mkdir -p /workspace/wheels
|
||||
cd /workspace/wheels
|
||||
curl -L -O https://github.com/OpenAccess-AI-Collective/axolotl/raw/wheels/wheels/deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
|
||||
curl -L -O https://github.com/OpenAccess-AI-Collective/axolotl/raw/wheels/wheels/flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
|
||||
pip install deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
|
||||
pip install flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
|
||||
pip install "peft @ git+https://github.com/huggingface/peft.git@main" --force-reinstall --no-dependencies
|
||||
|
||||
cd /workspace/
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
cd axolotl
|
||||
pip install -e .[int4]
|
||||
mkdir -p ~/.cache/huggingface/accelerate/
|
||||
cp configs/accelerate/default_config.yaml ~/.cache/huggingface/accelerate/default_config.yaml
|
||||
10
setup.py
10
setup.py
@@ -1,7 +1,9 @@
|
||||
from setuptools import setup, find_packages
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
install_requires = []
|
||||
with open("./requirements.txt", "r") as requirements_file:
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
# don't include peft yet until we check the int4
|
||||
# need to manually install peft for now...
|
||||
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
|
||||
@@ -17,10 +19,10 @@ setup(
|
||||
packages=find_packages(),
|
||||
install_requires=install_requires,
|
||||
extras_require={
|
||||
"int4": [
|
||||
"gptq": [
|
||||
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"int4_triton": [
|
||||
"gptq_triton": [
|
||||
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"extras": [
|
||||
|
||||
@@ -1,47 +1,76 @@
|
||||
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
|
||||
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
class FileReader:
|
||||
"""
|
||||
Reads a file and returns its contents as a string
|
||||
"""
|
||||
|
||||
def read(self, file_path):
|
||||
with open(file_path, "r") as file:
|
||||
with open(file_path, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
class FileWriter:
|
||||
"""
|
||||
Writes a string to a file
|
||||
"""
|
||||
|
||||
def __init__(self, file_path):
|
||||
self.file_path = file_path
|
||||
|
||||
def write(self, content):
|
||||
with open(self.file_path, "w") as file:
|
||||
with open(self.file_path, "w", encoding="utf-8") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
class StdoutWriter:
|
||||
"""
|
||||
Writes a string to stdout
|
||||
"""
|
||||
|
||||
def write(self, content):
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
|
||||
class JsonParser:
|
||||
"""
|
||||
Parses a string as JSON and returns the result
|
||||
"""
|
||||
|
||||
def parse(self, content):
|
||||
return json.loads(content)
|
||||
|
||||
|
||||
class JsonlSerializer:
|
||||
"""
|
||||
Serializes a list of JSON objects into a JSONL string
|
||||
"""
|
||||
|
||||
def serialize(self, data):
|
||||
lines = [json.dumps(item) for item in data]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class JsonToJsonlConverter:
|
||||
"""
|
||||
Converts a JSON file to JSONL
|
||||
"""
|
||||
|
||||
def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
|
||||
self.file_reader = file_reader
|
||||
self.file_writer = file_writer
|
||||
self.json_parser = json_parser
|
||||
self.jsonl_serializer = jsonl_serializer
|
||||
|
||||
def convert(self, input_file_path, output_file_path):
|
||||
def convert(
|
||||
self, input_file_path, output_file_path
|
||||
): # pylint: disable=unused-argument
|
||||
content = self.file_reader.read(input_file_path)
|
||||
data = self.json_parser.parse(content)
|
||||
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import IterableDataset
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
|
||||
|
||||
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
|
||||
|
||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||
# lets use the concept of middlewares to wrap each dataset, for example
|
||||
@@ -14,7 +16,14 @@ from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
|
||||
|
||||
|
||||
class TokenizedPromptDataset(IterableDataset):
|
||||
def __init__(
|
||||
"""
|
||||
Iterable dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
@@ -24,12 +33,16 @@ class TokenizedPromptDataset(IterableDataset):
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
count = 0
|
||||
# Loop through the entire dataset
|
||||
for example in iterator:
|
||||
try:
|
||||
yield self.prompt_tokenizer.tokenize_prompt(example)
|
||||
count += 1
|
||||
except InvalidDataException:
|
||||
pass
|
||||
if count == 0:
|
||||
raise RuntimeError("Expected at least one datapoint in dataset.")
|
||||
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
@@ -42,7 +55,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
tokenizer,
|
||||
datasets,
|
||||
@@ -82,10 +95,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
else:
|
||||
example_len = 0
|
||||
|
||||
if (
|
||||
not example_len
|
||||
or buffer_len + int(add_concat_token) + example_len
|
||||
> self.seq_length
|
||||
if not example_len or (
|
||||
buffer_len + int(add_concat_token) + example_len > self.seq_length
|
||||
):
|
||||
if buffer["input_ids"]:
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||
@@ -95,9 +106,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
if (
|
||||
labels.size() == input_ids.size()
|
||||
and attention_mask.size() == input_ids.size()
|
||||
if labels.size() == input_ids.size() and (
|
||||
attention_mask.size() == input_ids.size()
|
||||
):
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
@@ -108,15 +118,25 @@ class ConstantLengthDataset(IterableDataset):
|
||||
logging.warning(
|
||||
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
||||
)
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
|
||||
if example:
|
||||
# FIXME
|
||||
# just going to drop data points that are too long
|
||||
if len(example["input_ids"]) <= self.seq_length:
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
if (
|
||||
buffer["input_ids"]
|
||||
and input_ids[0] == self.tokenizer.bos_token_id
|
||||
):
|
||||
attention_mask[0] = 0
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
"""Flash attention monkey patch for llama model"""
|
||||
|
||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
|
||||
def forward(
|
||||
@@ -27,6 +25,7 @@ def forward(
|
||||
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
@@ -74,7 +73,11 @@ def forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
max_s = q_len
|
||||
cu_q_lens = torch.arange(
|
||||
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
||||
0,
|
||||
(bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=qkv.device,
|
||||
)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
@@ -82,35 +85,56 @@ def forward(
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(
|
||||
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
||||
x_unpad,
|
||||
"nnz (three h d) -> nnz three h d",
|
||||
three=3,
|
||||
h=nheads,
|
||||
)
|
||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
x_unpad,
|
||||
cu_q_lens,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(
|
||||
pad_input(
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
||||
indices,
|
||||
bsz,
|
||||
q_len,
|
||||
),
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
||||
return (
|
||||
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
return attention_mask
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn():
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
||||
|
||||
233
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
233
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers.models.llama.modeling_llama
|
||||
from torch import nn
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
logging.error("xformers not found! Please install it before trying to use it.")
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states, key_states, value_states, attn_bias=None
|
||||
)
|
||||
else:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
"""
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.models.llama.modeling_llama
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class XposRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scale_base=2048,
|
||||
use_xpos=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.scale_base = scale_base
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
||||
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.register_buffer("freqs_cached", freqs, persistent=False)
|
||||
|
||||
if not use_xpos:
|
||||
self.register_buffer("scale", None)
|
||||
self.register_buffer("scale_cached", torch.ones(1))
|
||||
return
|
||||
|
||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
||||
scale_cached = scale ** rearrange(power, "n -> n 1")
|
||||
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
||||
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
seq_len,
|
||||
):
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
||||
self.inv_freq
|
||||
)
|
||||
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
||||
|
||||
self.register_buffer("freqs_cached", freqs)
|
||||
|
||||
if self.scale is None:
|
||||
self.register_buffer(
|
||||
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
||||
)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
||||
|
||||
power = (t - (seq_len // 2)) / self.scale_base
|
||||
scale = self.scale ** rearrange(power, "n -> n 1")
|
||||
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
||||
self.register_buffer("scale_cached", scale)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
||||
freqs = freqs[position_ids, :]
|
||||
if scale.shape[-1] != 1:
|
||||
scale = scale[position_ids, :]
|
||||
|
||||
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
||||
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def replace_llama_rope_with_xpos_rope():
|
||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Module to load prompt strategies."""
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
@@ -7,8 +9,8 @@ def load(strategy, tokenizer, cfg):
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
fn = getattr(m, load_fn)
|
||||
return fn(tokenizer, cfg)
|
||||
except:
|
||||
pass
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
func = getattr(mod, load_fn)
|
||||
return func(tokenizer, cfg)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return None
|
||||
|
||||
@@ -1,21 +1,65 @@
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
InstructionPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.chat.value),
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaConcisePrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
|
||||
class AlpacaChatPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
self.prompt_style = PromptStyle.CHAT.value
|
||||
self.match_prompt_style()
|
||||
|
||||
|
||||
class NoSystemPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Null Prompter with no system prompts
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
turn_format = "{instruction} {input} "
|
||||
turn_no_input_format = "{instruction} "
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
pass
|
||||
|
||||
|
||||
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for AlpacaQA
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
"",
|
||||
@@ -23,9 +67,49 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.chat.value),
|
||||
class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for CamelAI datasets
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["message_1"],
|
||||
"",
|
||||
prompt["message_2"],
|
||||
)
|
||||
|
||||
|
||||
def load_concise(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaConcisePrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_camel_ai(tokenizer, cfg):
|
||||
return CamelAIPromptTokenizingStrategy(
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_no_prompt(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
UnpromptedPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
|
||||
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.instruct),
|
||||
AlpacaPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_no_prompt(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
UnpromptedPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
120
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
120
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Prompt strategies loader for alpaca instruction datasets with system prompts
|
||||
"""
|
||||
from typing import Generator, Tuple, Union
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
prompt["output"],
|
||||
prompt["system"],
|
||||
)
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
system,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt_w_system(
|
||||
system,
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class SystemDataPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Style Prompter that uses system prompts from the dataset
|
||||
"""
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = system + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = system + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for OpenOrca datasets
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
"",
|
||||
prompt["response"],
|
||||
prompt["system_prompt"],
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return load_chat(tokenizer, cfg)
|
||||
|
||||
|
||||
def load_instruct(tokenizer, cfg):
|
||||
return InstructionWSystemPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_chat(tokenizer, cfg):
|
||||
return InstructionWSystemPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_open_orca(tokenizer, cfg):
|
||||
return OpenOrcaPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
67
src/axolotl/prompt_strategies/context_qa.py
Normal file
67
src/axolotl/prompt_strategies/context_qa.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
# article, unanswerable_question, question, answer
|
||||
def load_404(tokenizer, cfg):
|
||||
return AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
AlpacaContextPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaContextPromptTokenizingStrategy(
|
||||
AlpacaContextPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Customized system prompted for concise QA
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"Use the following contextual information to concisely answer the question.\n"
|
||||
)
|
||||
system_no_input_prompt = (
|
||||
"Use the following contextual information to concisely answer the question.\n"
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question and answer
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"] + "\n===\n" + prompt["question"],
|
||||
"",
|
||||
prompt["answer"],
|
||||
)
|
||||
|
||||
|
||||
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
InstructionPromptTokenizingStrategy
|
||||
):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question that can't be answered
|
||||
from the context and a default response to that effect
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
|
||||
"",
|
||||
"The context provided does not contain any information about your inquiry. "
|
||||
"Therefore, I'm unable to answer your question based on the given context.",
|
||||
)
|
||||
@@ -1,11 +1,18 @@
|
||||
from typing import Union, Generator
|
||||
"""Module loading the CreativePromptTokenizingStrategy and similar classes"""
|
||||
|
||||
from typing import Generator, Tuple, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
|
||||
|
||||
class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Creative Answering
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
question = prompt["instruction"]
|
||||
answer = prompt[
|
||||
"revision"
|
||||
@@ -18,6 +25,10 @@ class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrat
|
||||
|
||||
|
||||
class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Creative Critique
|
||||
"""
|
||||
|
||||
user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
|
||||
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
|
||||
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
|
||||
@@ -49,12 +60,16 @@ Question: {question}
|
||||
Answer: {answer}
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
scores = yaml.dump(
|
||||
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["scores"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
critiques = yaml.dump(
|
||||
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["critiques"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
evaluation = scores + critiques
|
||||
question = prompt["instruction"]
|
||||
@@ -67,6 +82,10 @@ Answer: {answer}
|
||||
|
||||
|
||||
class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Creative Revise
|
||||
"""
|
||||
|
||||
user_prompt = """Definitions:
|
||||
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
|
||||
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
|
||||
@@ -81,12 +100,16 @@ Evaluation:
|
||||
{evaluation}
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
scores = yaml.dump(
|
||||
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["scores"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
critiques = yaml.dump(
|
||||
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["critiques"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
evaluation = scores + critiques
|
||||
question = prompt["instruction"]
|
||||
@@ -101,13 +124,19 @@ Evaluation:
|
||||
|
||||
|
||||
class CreativePrompterBase:
|
||||
"""
|
||||
Base class for Creative Prompters
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None,
|
||||
input: Union[ # pylint: disable=redefined-builtin, unused-argument
|
||||
None, str
|
||||
] = None,
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
if self.system_prompt:
|
||||
@@ -120,30 +149,51 @@ class CreativePrompterBase:
|
||||
|
||||
|
||||
class CreativeAnswerPrompter(CreativePrompterBase):
|
||||
"""
|
||||
Prompter for Creative Answering
|
||||
"""
|
||||
|
||||
system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
|
||||
|
||||
|
||||
class CreativeCritiquePrompter(CreativePrompterBase):
|
||||
"""
|
||||
Prompter for Creative Critique
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
|
||||
|
||||
class CreativeRevisePrompter(CreativePrompterBase):
|
||||
"""
|
||||
Prompter for Creative Revise
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
|
||||
|
||||
def load_answer(tokenizer, cfg):
|
||||
return CreativeAnsweringPromptTokenizingStrategy(
|
||||
CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
CreativeAnswerPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_critique(tokenizer, cfg):
|
||||
return CreativeCritiquePromptTokenizingStrategy(
|
||||
CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
CreativeCritiquePrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_revise(tokenizer, cfg):
|
||||
return CreativeRevisePromptTokenizingStrategy(
|
||||
CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
CreativeRevisePrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
@@ -1,29 +1,34 @@
|
||||
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Generator
|
||||
from typing import Generator, List, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import (
|
||||
PromptTokenizingStrategy,
|
||||
parse_tokenized_to_result,
|
||||
tokenize_prompt_default,
|
||||
)
|
||||
|
||||
IGNORE_TOKEN_ID = -100
|
||||
|
||||
|
||||
class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
bot_prefix_token_ids = []
|
||||
"""
|
||||
Tokenizing strategy for Pygmalion.
|
||||
"""
|
||||
|
||||
bot_prefix_token_ids: List[int] = []
|
||||
|
||||
def __init__(self, prompter, tokenizer, *args, **kwargs):
|
||||
super().__init__(prompter, tokenizer)
|
||||
super().__init__(prompter, tokenizer, *args, **kwargs)
|
||||
res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
|
||||
self.bot_prefix_token_ids = res["input_ids"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
result = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
}
|
||||
current_len = 0
|
||||
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
||||
result, current_len = tokenize_prompt_default()
|
||||
for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
||||
role, message = part
|
||||
if role == "system":
|
||||
prefix = "<|system|>"
|
||||
@@ -61,45 +66,29 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
else:
|
||||
logging.warning(f"unknown role in conversation: {role}")
|
||||
res = defaultdict(lambda: [])
|
||||
input_ids = res["input_ids"]
|
||||
input_len = len(input_ids)
|
||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||
result["attention_mask"][current_len : current_len + input_len] = [
|
||||
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
||||
]
|
||||
result["labels"][current_len : current_len + input_len] = labels
|
||||
current_len += input_len
|
||||
return result
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.sequence_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class PygmalionPrompter:
|
||||
"""
|
||||
Prompter for Pygmalion.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
|
||||
def build_prompt(
|
||||
self, source, *args, **kwargs # pylint: disable=unused-argument
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
for msg in source:
|
||||
yield msg["role"], msg["value"]
|
||||
|
||||
|
||||
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization strategy for asking bot to tell a joke and then explain why its funny
|
||||
"""
|
||||
|
||||
# title, text, explanation
|
||||
def get_conversation_thread(self, prompt):
|
||||
title = "" if not prompt["title"] else prompt["title"] + " "
|
||||
return [
|
||||
{"from": "human", "value": "Tell me a joke."},
|
||||
{"from": "gpt", "value": title + prompt["text"]},
|
||||
{"from": "human", "value": "Why is that joke funny?"},
|
||||
{"from": "gpt", "value": prompt["explanation"]},
|
||||
]
|
||||
67
src/axolotl/prompt_strategies/sharegpt_simple.py
Normal file
67
src/axolotl/prompt_strategies/sharegpt_simple.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_guanaco(tokenizer, cfg):
|
||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
return turns
|
||||
|
||||
|
||||
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
@@ -1,24 +1,33 @@
|
||||
"""Module containing PromptTokenizingStrategy and Prompter classes"""
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
|
||||
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
|
||||
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
|
||||
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||
|
||||
|
||||
class InvalidDataException(Exception):
|
||||
pass
|
||||
"""
|
||||
Exception raised when the data is invalid
|
||||
"""
|
||||
|
||||
|
||||
class PromptTokenizingStrategy(abc.ABC):
|
||||
"""
|
||||
Abstract class for tokenizing strategies
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter,
|
||||
@@ -35,59 +44,21 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def tokenize_prompt(self, prompt):
|
||||
pass
|
||||
|
||||
@functools.cache
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_user_token(self):
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
return False
|
||||
|
||||
@functools.cache
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_assistant_token(self):
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
return False
|
||||
|
||||
|
||||
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
instruction, input, response = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _build_full_prompt(self, instruction, input, response):
|
||||
return next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
response,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
@@ -111,8 +82,64 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
return result
|
||||
|
||||
|
||||
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(
|
||||
self, prompt
|
||||
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response # pylint: disable=redefined-builtin
|
||||
):
|
||||
return next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
response,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Alpaca prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
@@ -121,7 +148,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Alpaca Multiple Choice prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
"\n".join(f'- "{choice}"' for choice in prompt["choices"]),
|
||||
@@ -130,7 +161,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
|
||||
|
||||
|
||||
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Jeopardy prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
prompt["category"],
|
||||
@@ -139,7 +174,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for OpenAssistant prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["INSTRUCTION"],
|
||||
"",
|
||||
@@ -148,7 +187,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
||||
|
||||
|
||||
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for SummarizeTLDR prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"],
|
||||
"",
|
||||
@@ -157,7 +200,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
||||
|
||||
|
||||
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for GPTeacher prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
@@ -166,7 +213,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for NomicGPT4All prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["prompt"],
|
||||
"",
|
||||
@@ -175,28 +226,34 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> str:
|
||||
return prompt["text"]
|
||||
"""
|
||||
Tokenizing strategy for Completion prompts.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
instruction = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, None, None)
|
||||
full_prompt = self._build_full_prompt(prompt["text"], None, None)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _build_full_prompt(self, instruction, input, response):
|
||||
return next(iter(self.prompter.build_prompt(instruction)))
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response
|
||||
): # pylint: disable=redefined-builtin
|
||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||
|
||||
|
||||
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Reflection prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
(
|
||||
instruction,
|
||||
input,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
output,
|
||||
reflection,
|
||||
corrected,
|
||||
@@ -223,7 +280,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _build_full_prompt(self, instruction, input, output, reflection, corrected):
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, output, reflection, corrected
|
||||
): # pylint: disable=redefined-builtin
|
||||
return next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
@@ -236,7 +295,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
)
|
||||
)
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True):
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
@@ -257,7 +316,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
|
||||
class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Alpaca Reflection prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
@@ -268,20 +331,19 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
result = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
}
|
||||
current_len = 0
|
||||
result, current_len = tokenize_prompt_default()
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
try:
|
||||
for i, part in enumerate(
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
@@ -289,7 +351,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
part = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=False, strip_bos_token=True
|
||||
part.strip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
@@ -300,32 +364,39 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
part = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistent response, should end with an eos token
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=True, strip_bos_token=True
|
||||
part.strip(),
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [assistant_token, *res["input_ids"]]
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
elif part[0] == "SYSTEM:":
|
||||
part = part[1] # Ignore the system role from preamble
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
logging.warning("unhandled role: " + part[0])
|
||||
else:
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
input_ids = res["input_ids"]
|
||||
input_len = len(input_ids)
|
||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||
result["attention_mask"][current_len : current_len + input_len] = [
|
||||
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
||||
]
|
||||
result["labels"][current_len : current_len + input_len] = labels
|
||||
current_len += input_len
|
||||
logging.warning(f"unhandled role: {part[0]}")
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as e:
|
||||
raise InvalidDataException(str(e))
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
@@ -349,3 +420,40 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
Returns the default values for the tokenize prompt function
|
||||
"""
|
||||
|
||||
result: Dict[str, List[int]] = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
}
|
||||
current_len = 0
|
||||
return result, current_len
|
||||
|
||||
|
||||
def parse_tokenized_to_result(
|
||||
result: Dict[str, List[int]],
|
||||
current_len: int,
|
||||
res: Dict[str, List[int]],
|
||||
labels: List[int],
|
||||
pad_token_id: Union[int, None] = None,
|
||||
) -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
|
||||
"""
|
||||
|
||||
input_ids = res["input_ids"]
|
||||
input_len = len(input_ids)
|
||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||
result["attention_mask"][current_len : current_len + input_len] = [
|
||||
1 if x != pad_token_id else 0 for x in input_ids
|
||||
]
|
||||
result["labels"][current_len : current_len + input_len] = labels
|
||||
current_len += input_len
|
||||
|
||||
return result, current_len
|
||||
|
||||
@@ -1,110 +1,155 @@
|
||||
import copy
|
||||
"""Module containing prompters"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from enum import auto, Enum
|
||||
from typing import List, Tuple, Any, Union, Generator
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
IGNORE_TOKEN_ID = -100
|
||||
|
||||
|
||||
class PromptStyle(Enum):
|
||||
instruct = "instruct"
|
||||
chat = "chat"
|
||||
"""
|
||||
Enum for prompt styles
|
||||
"""
|
||||
|
||||
INSTRUCT = "instruct"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
class AlpacaPrompter:
|
||||
"""
|
||||
Base class for alpaca prompters
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
prompt_style = None
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
prompt_style: Optional[PromptStyle] = None
|
||||
|
||||
def __init__(self, prompt_style=PromptStyle.instruct.value):
|
||||
self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
|
||||
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
|
||||
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
|
||||
self.match_prompt_style()
|
||||
|
||||
def match_prompt_style(self):
|
||||
if self.prompt_style == PromptStyle.instruct.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
self.turn_no_input_format = (
|
||||
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
self.prompt_no_input = (
|
||||
self.system_no_input_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
self.response_split = "### Response:"
|
||||
if self.prompt_style == PromptStyle.chat.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
)
|
||||
self.prompt_no_input = (
|
||||
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = self.prompt_input.format(instruction=instruction, input=input)
|
||||
res = self.system_prompt + self.turn_format.format(
|
||||
instruction=instruction, input=input
|
||||
)
|
||||
else:
|
||||
res = self.prompt_no_input.format(instruction=instruction)
|
||||
res = self.system_no_input_prompt + self.turn_no_input_format.format(
|
||||
instruction=instruction
|
||||
)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.split(self.response_split)[1].strip()
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for alpaca no system prompt
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
|
||||
|
||||
class JeopardyPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for Jeopardy
|
||||
"""
|
||||
|
||||
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
|
||||
|
||||
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for multiple choice explain
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"Choose the answer that best answers the question. Explain your reasoning."
|
||||
"Choose the answer that best answers the question. Explain your reasoning.\n"
|
||||
)
|
||||
system_no_input_prompt = (
|
||||
"Choose the answer that best answers the question. Explain your reasoning.\n"
|
||||
)
|
||||
|
||||
|
||||
class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
||||
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
|
||||
"""
|
||||
Prompter for multiple choice concise
|
||||
"""
|
||||
|
||||
system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
|
||||
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
|
||||
|
||||
class SummarizeTLDRPrompter(AlpacaPrompter):
|
||||
prompt_no_input = (
|
||||
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||
)
|
||||
"""
|
||||
Prompter for summarize TLDR
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||
|
||||
|
||||
class CompletionPrompter:
|
||||
"""
|
||||
Prompter for completion
|
||||
"""
|
||||
|
||||
def build_prompt(
|
||||
self, instruction: str, input=None, output=None
|
||||
self,
|
||||
instruction: str,
|
||||
input=None, # pylint: disable=redefined-builtin, unused-argument
|
||||
output=None, # pylint: disable=unused-argument
|
||||
) -> Generator[str, None, None]:
|
||||
yield instruction
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.strip()
|
||||
|
||||
|
||||
class GPTeacherPrompter(AlpacaPrompter):
|
||||
...
|
||||
"""
|
||||
Prompter for GPTeacher
|
||||
"""
|
||||
|
||||
|
||||
class NomicGPT4AllPrompter(AlpacaPrompter):
|
||||
...
|
||||
"""
|
||||
Prompter for NomicGPT4All
|
||||
"""
|
||||
|
||||
|
||||
class ReflectAlpacaPrompter:
|
||||
"""
|
||||
Prompter for ReflectAlpaca
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
||||
|
||||
@@ -120,7 +165,7 @@ class ReflectAlpacaPrompter:
|
||||
self.match_prompt_style()
|
||||
|
||||
def match_prompt_style(self):
|
||||
if self.prompt_style == PromptStyle.instruct.value:
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
@@ -131,7 +176,7 @@ class ReflectAlpacaPrompter:
|
||||
)
|
||||
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
||||
self.response_split = "### Final Response:"
|
||||
if self.prompt_style == PromptStyle.chat.value:
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
)
|
||||
@@ -146,7 +191,7 @@ class ReflectAlpacaPrompter:
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
@@ -159,14 +204,13 @@ class ReflectAlpacaPrompter:
|
||||
res = self.prompt_no_input.format(instruction=instruction)
|
||||
if output and reflection and corrected:
|
||||
label = self.agent_label.format(
|
||||
output=output, reflection=reflection, corrected=corrected
|
||||
output=output,
|
||||
reflection=reflection,
|
||||
corrected=corrected,
|
||||
)
|
||||
res = f"{res}{label}"
|
||||
yield res
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.split(self.response_split)[1].strip()
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
@@ -187,18 +231,18 @@ class Conversation:
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: str = None
|
||||
sep2: Optional[str] = None
|
||||
|
||||
def get_prompt(self) -> Generator[str, None, None]:
|
||||
seps = [self.sep, self.sep2]
|
||||
preamble = self.system + seps[0]
|
||||
yield preamble
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
||||
# seps = [self.sep, self.sep2]
|
||||
preamble = self.system + self.sep
|
||||
yield ("SYSTEM:", preamble)
|
||||
for _, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield (role + ":", " " + message)
|
||||
else:
|
||||
logging.warning("role with empty message: " + role)
|
||||
yield (role + ":",)
|
||||
logging.warning(f"role with empty message: {role}")
|
||||
yield (role + ":", "")
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
@@ -215,32 +259,35 @@ class Conversation:
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
conv_vicuna_v1_1 = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
|
||||
class ShareGPTPrompter:
|
||||
def __init__(self, prompt_style=None):
|
||||
if prompt_style != PromptStyle.chat.value:
|
||||
raise Exception(
|
||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
||||
if prompt_style != PromptStyle.CHAT.value:
|
||||
raise ValueError(
|
||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
||||
)
|
||||
system: str = (
|
||||
system_prompt
|
||||
if system_prompt
|
||||
else (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
)
|
||||
)
|
||||
self._conversation = Conversation(
|
||||
system=system,
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
|
||||
# def match_prompt_style(self):
|
||||
# if self.prompt_style == PromptStyle.chat.value:
|
||||
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||
# self.response_split = "ASSISTANT:"
|
||||
|
||||
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
# ignore the system prompt if provided
|
||||
if source[0]["from"] == "system":
|
||||
source.pop(0)
|
||||
@@ -250,7 +297,7 @@ class ShareGPTPrompter:
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError
|
||||
|
||||
conv = conv_vicuna_v1_1.copy()
|
||||
conv = self._conversation.copy()
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
try:
|
||||
@@ -261,9 +308,9 @@ class ShareGPTPrompter:
|
||||
):
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as e:
|
||||
except IndexError as err:
|
||||
# sometimes there is a bing or system chat
|
||||
raise e
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for j, sentence in enumerate(source):
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
"""Callbacks for Trainer class"""
|
||||
|
||||
import os
|
||||
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import (
|
||||
Seq2SeqTrainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
TrainerState,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the PEFT adapter"""
|
||||
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
@@ -19,10 +23,47 @@ class SavePeftModelCallback(TrainerCallback):
|
||||
**kwargs,
|
||||
):
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||
kwargs["model"].save_pretrained(peft_model_path)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
class SaveBetterTransformerModelCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the BetterTransformer wrapped model"""
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Save
|
||||
if (
|
||||
args.save_strategy == IntervalStrategy.STEPS
|
||||
and args.save_steps > 0
|
||||
and state.global_step % args.save_steps == 0
|
||||
):
|
||||
control.should_save = True
|
||||
|
||||
if control.should_save:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
model = BetterTransformer.reverse(kwargs["model"])
|
||||
model.save_pretrained(checkpoint_folder)
|
||||
# FIXME - need to cleanup old checkpoints
|
||||
|
||||
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
|
||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||
control.should_save = False
|
||||
return control
|
||||
|
||||
@@ -1,42 +1,38 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from datasets import (
|
||||
load_from_disk,
|
||||
load_dataset,
|
||||
IterableDataset,
|
||||
Dataset,
|
||||
concatenate_datasets,
|
||||
DatasetDict,
|
||||
)
|
||||
import torch
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
CompletionPromptTokenizingStrategy,
|
||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
CompletionPromptTokenizingStrategy,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
SummarizeTLDRPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
CompletionPrompter,
|
||||
GPTeacherPrompter,
|
||||
JeopardyPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
ReflectAlpacaPrompter,
|
||||
ShareGPTPrompter,
|
||||
JeopardyPrompter,
|
||||
CompletionPrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
)
|
||||
|
||||
|
||||
@@ -45,11 +41,13 @@ def load_tokenized_prepared_datasets(
|
||||
) -> DatasetDict:
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
ds_hash = str(
|
||||
md5(
|
||||
md5( # nosec
|
||||
(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|".join(
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
)
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
).encode("utf-8")
|
||||
@@ -65,10 +63,11 @@ def load_tokenized_prepared_datasets(
|
||||
try:
|
||||
if cfg.push_dataset_to_hub:
|
||||
dataset = load_dataset(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
pass
|
||||
|
||||
if dataset:
|
||||
@@ -80,44 +79,80 @@ def load_tokenized_prepared_datasets(
|
||||
else:
|
||||
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||
logging.info("Loading raw datasets...")
|
||||
|
||||
if cfg.seed:
|
||||
seed = cfg.seed
|
||||
else:
|
||||
logging.info("No seed provided, using default seed of 42")
|
||||
seed = 42
|
||||
|
||||
datasets = []
|
||||
# pylint: disable=invalid-name
|
||||
for d in cfg.datasets:
|
||||
ds: Union[Dataset, DatasetDict] = None
|
||||
ds_from_hub = False
|
||||
try:
|
||||
load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
|
||||
load_dataset(
|
||||
d.path,
|
||||
streaming=True,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
if Path(d.path).exists():
|
||||
ds: Dataset = load_dataset(
|
||||
"json", data_files=d.path, streaming=False, split=None
|
||||
)
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
data_files=d.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
if d.data_files:
|
||||
ds: Dataset = load_dataset(
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
streaming=False,
|
||||
data_files=d.data_files,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=use_auth_token)
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
streaming=False,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
fp = hf_hub_download(
|
||||
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
||||
repo_id=d.path,
|
||||
repo_type="dataset",
|
||||
filename=d.data_files,
|
||||
)
|
||||
ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
|
||||
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
|
||||
if not ds:
|
||||
raise Exception("unhandled dataset load")
|
||||
raise ValueError("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if d.shards:
|
||||
if "train" in ds:
|
||||
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||
num_shards=d.shards, index=0
|
||||
)
|
||||
else:
|
||||
ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
|
||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||
d_type = d.type
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
@@ -218,13 +253,21 @@ def load_tokenized_prepared_datasets(
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||
suffix = ""
|
||||
if ":load_" in d.type:
|
||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||
logging.error(
|
||||
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
)
|
||||
logging.info("tokenizing, merging, and shuffling master dataset")
|
||||
|
||||
samples = []
|
||||
samples: List[int] = []
|
||||
for d in datasets:
|
||||
samples = samples + [i for i in d]
|
||||
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
||||
samples = samples + list(d)
|
||||
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||
if cfg.local_rank == 0:
|
||||
logging.info(
|
||||
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
||||
@@ -242,8 +285,10 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
||||
) -> (Dataset, Dataset):
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
@@ -256,13 +301,15 @@ def load_prepare_datasets(
|
||||
# see if we can go ahead and load the stacked dataset
|
||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||
ds_hash = str(
|
||||
md5(
|
||||
md5( # nosec
|
||||
(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ str(max_packed_sequence_len)
|
||||
+ seed
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|".join(
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
)
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
).encode("utf-8")
|
||||
@@ -282,10 +329,11 @@ def load_prepare_datasets(
|
||||
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
)
|
||||
dataset = load_dataset(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
pass
|
||||
|
||||
if dataset:
|
||||
@@ -319,7 +367,7 @@ def load_prepare_datasets(
|
||||
logging.info(
|
||||
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
||||
)
|
||||
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
||||
dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
|
||||
# filter out bad data
|
||||
dataset = Dataset.from_list(
|
||||
@@ -343,7 +391,8 @@ def load_prepare_datasets(
|
||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
)
|
||||
dataset.push_to_hub(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
private=True,
|
||||
)
|
||||
else:
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
@@ -355,11 +404,131 @@ def load_prepare_datasets(
|
||||
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
||||
)
|
||||
dataset = dataset.shard(
|
||||
num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
|
||||
num_shards=cfg.dataset_shard_num,
|
||||
index=cfg.dataset_shard_idx,
|
||||
)
|
||||
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
if cfg.val_set_size:
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
else:
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def encode_pretraining(tokenizer, max_tokens, examples):
|
||||
res = tokenizer(
|
||||
examples["text"],
|
||||
truncation=True,
|
||||
max_length=max_tokens - 2,
|
||||
add_special_tokens=True,
|
||||
)
|
||||
# Convert to PyTorch tensors
|
||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||
new_input_ids = []
|
||||
new_attention_mask = []
|
||||
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
||||
for i, _ in enumerate(input_ids):
|
||||
input_ids[i] = torch.cat(
|
||||
(
|
||||
input_ids[i],
|
||||
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
||||
|
||||
# Concatenate tokens so that their lengths are less than max_tokens
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
for ids, mask in zip(input_ids, attention_mask):
|
||||
if buffer_input_ids.numel() == max_tokens:
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
else:
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
|
||||
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
||||
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
|
||||
ret = {
|
||||
"input_ids": [seq.tolist() for seq in new_input_ids],
|
||||
"labels": [seq.tolist() for seq in new_input_ids],
|
||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
||||
}
|
||||
|
||||
logging.debug(len(ret["input_ids"]))
|
||||
return ret
|
||||
|
||||
|
||||
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
dataset = load_dataset(path, streaming=True, split="train")
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||
# TODO dynamically figure out which columns/features to remove
|
||||
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
|
||||
return dataset
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Module containing the DictDefault class"""
|
||||
|
||||
from addict import Dict
|
||||
|
||||
|
||||
|
||||
@@ -1,52 +1,53 @@
|
||||
"""Module for models and model loading"""
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import (
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
AutoConfig,
|
||||
BitsAndBytesConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers import (
|
||||
LlamaForCausalLM,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
except:
|
||||
logging.warning(
|
||||
"This version of transformers does not support Llama. Consider upgrading."
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from peft import PeftModel, PeftConfig
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from transformers import PreTrainedTokenizer
|
||||
from peft import PeftConfig # noqa: F401
|
||||
|
||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
use_fast = True # this is the default
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
)
|
||||
|
||||
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
@@ -54,7 +55,10 @@ def load_tokenizer(
|
||||
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
||||
if tokenizer.__class__.__name__ in [
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
]:
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
@@ -62,8 +66,8 @@ def load_tokenizer(
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
if cfg.special_tokens:
|
||||
for k, v in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens({k: v})
|
||||
for k, val in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens({k: val})
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(list(cfg.tokens))
|
||||
|
||||
@@ -71,39 +75,62 @@ def load_tokenizer(
|
||||
|
||||
|
||||
def load_model(
|
||||
base_model,
|
||||
base_model_config,
|
||||
model_type,
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter="lora",
|
||||
inference=False,
|
||||
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||
):
|
||||
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
|
||||
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
"""
|
||||
Load a model from a base model and a model type.
|
||||
"""
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
is_llama_derived_model = "llama" in base_model or (
|
||||
cfg.is_llama_derived_model = "llama" in base_model or (
|
||||
cfg.model_type and "llama" in cfg.model_type.lower()
|
||||
)
|
||||
|
||||
if is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and inference is False:
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||
|
||||
logging.info("patching with flash attention")
|
||||
replace_llama_attn_with_flash_attn()
|
||||
elif is_llama_derived_model and cfg.xformers_attention:
|
||||
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
|
||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
)
|
||||
|
||||
logging.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_sdp_attention,
|
||||
)
|
||||
|
||||
if cfg.bf16:
|
||||
logging.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
MEM_TOKEN,
|
||||
patch_llama_with_landmark_attn,
|
||||
)
|
||||
|
||||
logging.info("patching with landmark attention")
|
||||
patch_llama_with_landmark_attn()
|
||||
|
||||
# Note: This might overwrite previous additional_special_tokens
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
)
|
||||
|
||||
logging.info("patching with xpos rope")
|
||||
replace_llama_rope_with_xpos_rope()
|
||||
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16:
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
@@ -114,12 +141,21 @@ def load_model(
|
||||
)
|
||||
|
||||
replace_peft_model_with_int4_lora_model()
|
||||
from peft import prepare_model_for_int8_training
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise e
|
||||
except Exception as err:
|
||||
logging.exception(err)
|
||||
raise err
|
||||
|
||||
try:
|
||||
from peft import prepare_model_for_kbit_training
|
||||
except ImportError:
|
||||
# For backward compatibility
|
||||
from peft import (
|
||||
prepare_model_for_int8_training as prepare_model_for_kbit_training,
|
||||
)
|
||||
|
||||
model_kwargs = {}
|
||||
if cfg.model_revision:
|
||||
model_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
@@ -130,7 +166,7 @@ def load_model(
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
try:
|
||||
if cfg.gptq and is_llama_derived_model:
|
||||
if cfg.gptq and cfg.is_llama_derived_model:
|
||||
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@@ -155,7 +191,7 @@ def load_model(
|
||||
"unable to find a cached model file, this will likely fail..."
|
||||
)
|
||||
model_path = str(cache_model_path)
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
model_path = cfg.base_model
|
||||
model, _ = load_llama_model_4bit_low_ram(
|
||||
base_model_config if base_model_config else base_model,
|
||||
@@ -168,9 +204,13 @@ def load_model(
|
||||
else True,
|
||||
)
|
||||
load_in_8bit = False
|
||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config = LlamaConfig.from_pretrained(base_model_config)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
@@ -203,21 +243,37 @@ def load_model(
|
||||
# device=cfg.device,
|
||||
# )
|
||||
# model.train() # sets to train instead of eval mode
|
||||
elif model_type:
|
||||
elif model_type and not cfg.trust_remote_code:
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(config, "max_seq_len")
|
||||
and config.max_seq_len
|
||||
and cfg.sequence_len > config.max_seq_len
|
||||
):
|
||||
config.max_seq_len = cfg.sequence_len
|
||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(config, "max_sequence_length")
|
||||
and config.max_sequence_length
|
||||
and cfg.sequence_len > config.max_sequence_length
|
||||
):
|
||||
config.max_sequence_length = cfg.sequence_len
|
||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
@@ -225,20 +281,21 @@ def load_model(
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as err: # pylint: disable=broad-exception-caught
|
||||
logging.error(
|
||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||
)
|
||||
logging.exception(e)
|
||||
logging.exception(err)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -246,12 +303,23 @@ def load_model(
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
|
||||
if (
|
||||
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
|
||||
and not cfg.gptq
|
||||
and (load_in_8bit or cfg.load_in_4bit)
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings
|
||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||
):
|
||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||
model = prepare_model_for_int8_training(model)
|
||||
logging.warning(
|
||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
):
|
||||
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
|
||||
model, lora_config = load_adapter(model, cfg, adapter)
|
||||
|
||||
@@ -261,14 +329,14 @@ def load_model(
|
||||
if cfg.gptq:
|
||||
# Scales to half
|
||||
logging.info("Fitting 4bit scales and zeros to half")
|
||||
for n, m in model.named_modules():
|
||||
if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
|
||||
type(m)
|
||||
for _, module in model.named_modules():
|
||||
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
|
||||
type(module)
|
||||
):
|
||||
if hasattr(m, "is_v1_model") and m.is_v1_model:
|
||||
m.zeros = m.zeros.half()
|
||||
m.scales = m.scales.half()
|
||||
m.bias = m.bias.half()
|
||||
if hasattr(module, "is_v1_model") and module.is_v1_model:
|
||||
module.zeros = module.zeros.half()
|
||||
module.scales = module.scales.half()
|
||||
module.bias = module.bias.half()
|
||||
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
@@ -278,8 +346,8 @@ def load_model(
|
||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||
# so let's only set it for the 4bit, see
|
||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||
setattr(model, 'is_parallelizable', True)
|
||||
setattr(model, 'model_parallel', True)
|
||||
setattr(model, "is_parallelizable", True)
|
||||
setattr(model, "model_parallel", True)
|
||||
|
||||
requires_grad = []
|
||||
for name, param in model.named_parameters(recurse=True):
|
||||
@@ -289,6 +357,9 @@ def load_model(
|
||||
logging.warning("there are no parameters that require gradient updates")
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, lora_config
|
||||
|
||||
@@ -308,11 +379,7 @@ def load_adapter(model, cfg, adapter):
|
||||
|
||||
def load_llama_adapter(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
from peft import (
|
||||
AdaptionPromptConfig,
|
||||
get_peft_model,
|
||||
PeftModel,
|
||||
)
|
||||
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
|
||||
|
||||
peft_config = AdaptionPromptConfig(
|
||||
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
||||
@@ -325,7 +392,6 @@ def load_llama_adapter(model, cfg):
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
device_map=cfg.device_map,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
else:
|
||||
@@ -357,11 +423,7 @@ def find_all_linear_names(bits, model):
|
||||
def load_lora(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
PeftModel,
|
||||
)
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||
|
||||
@@ -391,8 +453,7 @@ def load_lora(model, cfg):
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
device_map=cfg.device_map,
|
||||
# torch_dtype=torch.float16,
|
||||
is_trainable=not cfg.inference,
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
173
src/axolotl/utils/sampler.py
Normal file
173
src/axolotl/utils/sampler.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# pylint: skip-file
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[int] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
|
||||
while True:
|
||||
# binary search [l, r)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
m = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + m], c, n):
|
||||
left = m
|
||||
else:
|
||||
right = m
|
||||
|
||||
# use length l
|
||||
batch = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
assert len(batch) <= n
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
class MultipackDistributedBatchSampler(Sampler):
|
||||
"""Unpadded length sampling using Multipack.
|
||||
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_max_length: int,
|
||||
lengths: List[int],
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
seed: int = 0,
|
||||
):
|
||||
# Get rank
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.seed = seed
|
||||
|
||||
self.batch_max_length = batch_max_length
|
||||
self.lengths = lengths
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
self.epoch = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(
|
||||
len(self.lengths)
|
||||
)
|
||||
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
c=self.batch_max_length,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
batches = [indices[batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches()
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
@@ -1,7 +1,16 @@
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
"""Module for custom LRScheduler class"""
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
|
||||
class InterpolatingLogScheduler(LRScheduler):
|
||||
"""
|
||||
A scheduler that interpolates learning rates in a logarithmic fashion
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
|
||||
"""A scheduler that interpolates learning rates in a logarithmic fashion
|
||||
|
||||
@@ -19,7 +28,9 @@ class InterpolatingLogScheduler(LRScheduler):
|
||||
self.num_steps = num_steps
|
||||
self.min_lr = min_lr
|
||||
self.max_lr = max_lr
|
||||
self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
|
||||
self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name
|
||||
1 / (num_steps - 1)
|
||||
)
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
@@ -34,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
|
||||
lrs = [self.max_lr for base_lr in self.base_lrs]
|
||||
|
||||
return lrs
|
||||
|
||||
|
||||
def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
|
||||
current_step: int,
|
||||
*,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
num_cycles: float
|
||||
):
|
||||
if current_step < num_warmup_steps:
|
||||
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(
|
||||
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
||||
)
|
||||
|
||||
|
||||
def get_cosine_schedule_with_quadratic_warmup(
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
num_cycles: float = 0.5,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
||||
initial lr set in the optimizer.
|
||||
|
||||
Args:
|
||||
optimizer ([`~torch.optim.Optimizer`]):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
num_warmup_steps (`int`):
|
||||
The number of steps for the warmup phase.
|
||||
num_training_steps (`int`):
|
||||
The total number of training steps.
|
||||
num_cycles (`float`, *optional*, defaults to 0.5):
|
||||
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
||||
following a half-cosine).
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
|
||||
Return:
|
||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||
"""
|
||||
|
||||
lr_lambda = partial(
|
||||
_get_cosine_schedule_with_quadratic_warmup_lr_lambda,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
num_cycles=num_cycles,
|
||||
)
|
||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from termcolor import colored
|
||||
"""Module for tokenization utilities"""
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def check_dataset_labels(dataset, tokenizer):
|
||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||
@@ -17,7 +21,7 @@ def check_example_labels(example, tokenizer):
|
||||
# You can compare the input_ids and labels element-wise
|
||||
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
||||
colored_tokens = []
|
||||
for i, (input_id, label_id, mask) in enumerate(
|
||||
for _, (input_id, label_id, mask) in enumerate(
|
||||
zip(input_ids, labels, attention_mask)
|
||||
):
|
||||
decoded_input_token = tokenizer.decode(input_id)
|
||||
@@ -30,3 +34,5 @@ def check_example_labels(example, tokenizer):
|
||||
|
||||
logging.info(" ".join(colored_tokens))
|
||||
logging.info("\n\n\n")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
@@ -1,28 +1,204 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
import transformers
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers import EarlyStoppingCallback, Trainer
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
||||
from axolotl.utils.callbacks import SavePeftModelCallback
|
||||
from axolotl.utils.callbacks import (
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
)
|
||||
from axolotl.utils.sampler import MultipackDistributedBatchSampler
|
||||
from axolotl.utils.schedulers import (
|
||||
InterpolatingLogScheduler,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
)
|
||||
|
||||
IGNORE_LABEL_ID = -100
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(Trainer):
|
||||
def _find_multiple(val1, val2):
|
||||
return (-(val1 // -val2)) * val2
|
||||
|
||||
|
||||
def batch_to_tensor(batch, pad_id=0, dtype=torch.long, loss_dtype=torch.bfloat16):
|
||||
# Pad an unused item to reach multiple of 64, for faster GEMM
|
||||
pad_cur_len = sum(list(batch["length"]))
|
||||
pad_len = _find_multiple(pad_cur_len, 64) - pad_cur_len
|
||||
|
||||
if pad_len > 0:
|
||||
assert pad_len < 64
|
||||
|
||||
batch["input_ids"].append([pad_id] * pad_len)
|
||||
batch["labels"].append([pad_id] * pad_len)
|
||||
batch["attention_mask"].append([0] * pad_len)
|
||||
batch["length"].append(pad_len)
|
||||
|
||||
# seqlen
|
||||
batch_lengths = torch.tensor(list(batch["length"]), dtype=torch.int32, device="cpu")
|
||||
|
||||
max_seqlen = torch.max(batch_lengths)
|
||||
cu_seqlens = torch.nn.functional.pad(
|
||||
batch_lengths.cumsum(-1, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
|
||||
# nz elements
|
||||
nz_num = cu_seqlens[-1]
|
||||
nz_input_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
|
||||
nz_position_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
|
||||
nz_shifted_label_ids = torch.zeros(
|
||||
(nz_num,), dtype=dtype, pin_memory=True, device="cpu"
|
||||
)
|
||||
nz_shifted_loss_weights = torch.zeros(
|
||||
(nz_num,), dtype=loss_dtype, pin_memory=True, device="cpu"
|
||||
)
|
||||
|
||||
index = 0
|
||||
for token_list, length, labels_list in zip(
|
||||
batch["input_ids"], batch["length"], batch["labels"]
|
||||
):
|
||||
tokens = torch.tensor(token_list, dtype=dtype, device="cpu")
|
||||
position_ids = torch.arange(length, dtype=dtype, device="cpu")
|
||||
|
||||
# Input IDs & shifted labels
|
||||
# shifted_label_ids = torch.where(masks, tokens, IGNORE_LABEL_ID)
|
||||
shifted_label_ids = labels_list
|
||||
shifted_label_ids = torch.nn.functional.pad(
|
||||
shifted_label_ids[1:], (0, 1), "constant", IGNORE_LABEL_ID
|
||||
)
|
||||
|
||||
nz_input_ids[index : index + length] = tokens
|
||||
nz_position_ids[index : index + length] = position_ids
|
||||
nz_shifted_label_ids[index : index + length] = shifted_label_ids
|
||||
|
||||
# Loss weights
|
||||
mask_count = sum(1 for label in labels_list[1:] if label != IGNORE_LABEL_ID)
|
||||
loss_weight = (
|
||||
1 / mask_count if mask_count > 0 else 0
|
||||
) # Avoid division by zero for paddings
|
||||
|
||||
nz_shifted_loss_weights[index : index + length] = loss_weight
|
||||
|
||||
index += length
|
||||
|
||||
# inputs
|
||||
return {
|
||||
"max_seqlen": max_seqlen,
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"nz_input_ids": nz_input_ids,
|
||||
"nz_position_ids": nz_position_ids,
|
||||
"nz_shifted_label_ids": nz_shifted_label_ids,
|
||||
"nz_shifted_loss_weights": nz_shifted_loss_weights,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
lengths = np.array([len(sample["input_ids"]) for sample in self.train_dataset])
|
||||
return MultipackDistributedBatchSampler(
|
||||
batch_max_length=self.args.per_device_train_batch_size
|
||||
* self.args.max_seq_length,
|
||||
lengths=lengths,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
lengths = np.array([len(sample["input_ids"]) for sample in eval_dataset])
|
||||
return MultipackDistributedBatchSampler(
|
||||
batch_max_length=self.args.per_device_eval_batch_size
|
||||
* self.args.max_seq_length,
|
||||
lengths=lengths,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
num_training_steps = num_training_steps
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
@@ -50,19 +226,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
save_steps = cfg.save_steps
|
||||
eval_steps = cfg.eval_steps
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
if cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
|
||||
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
|
||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
if cfg.gradient_checkpointing is not None:
|
||||
|
||||
if cfg.seed:
|
||||
training_arguments_kwargs["seed"] = cfg.seed
|
||||
|
||||
if cfg.gradient_checkpointing:
|
||||
if cfg.gptq:
|
||||
from alpaca_lora_4bit.gradient_checkpointing import (
|
||||
apply_gradient_checkpointing,
|
||||
@@ -85,6 +263,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
|
||||
|
||||
if cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
|
||||
|
||||
# deepspeed
|
||||
if (
|
||||
os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
|
||||
@@ -97,7 +278,25 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
# TODO search Path("./") for one
|
||||
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
||||
|
||||
training_args = transformers.TrainingArguments(
|
||||
if cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
||||
if cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
||||
if cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
||||
if cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
||||
|
||||
if cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
|
||||
if cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
max_steps=total_num_steps * cfg.num_epochs,
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
if cfg.eval_batch_size is not None
|
||||
@@ -107,18 +306,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
num_train_epochs=cfg.num_epochs,
|
||||
learning_rate=cfg.learning_rate,
|
||||
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
||||
save_strategy="steps" if save_steps else "epoch",
|
||||
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
|
||||
save_steps=save_steps,
|
||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||
save_steps=cfg.save_steps,
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=3,
|
||||
load_best_model_at_end=True
|
||||
if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
|
||||
and cfg.val_set_size > 0
|
||||
and save_steps is not None
|
||||
and save_steps % eval_steps == 0
|
||||
and cfg.load_in_8bit is not True
|
||||
else False,
|
||||
load_best_model_at_end=(
|
||||
cfg.load_best_model_at_end is not False
|
||||
and cfg.val_set_size > 0
|
||||
and cfg.save_steps
|
||||
and cfg.save_steps % cfg.eval_steps == 0
|
||||
and cfg.load_in_8bit is not True
|
||||
)
|
||||
or False,
|
||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||
group_by_length=cfg.group_by_length,
|
||||
report_to="wandb" if cfg.use_wandb else None,
|
||||
@@ -140,7 +340,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if (
|
||||
cfg.optimizer == "adamw_bnb_8bit"
|
||||
and not cfg.gptq
|
||||
and not "deepspeed" in training_arguments_kwargs
|
||||
and "deepspeed" not in training_arguments_kwargs
|
||||
and not cfg.fsdp
|
||||
):
|
||||
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
||||
@@ -206,9 +406,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
|
||||
if cfg.local_rank == 0 and cfg.adapter in [
|
||||
"lora",
|
||||
"qlora",
|
||||
]: # only save in rank 0
|
||||
callbacks.append(SavePeftModelCallback)
|
||||
|
||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True,
|
||||
}
|
||||
@@ -217,10 +423,30 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
else:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from functools import partial
|
||||
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
|
||||
logging.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [train_dataset, eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = (
|
||||
OneCycleLRSchedulerTrainer
|
||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
||||
else transformers.Trainer
|
||||
else AxolotlTrainer
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
model=model,
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
"""Module for validating config files"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.batch_size:
|
||||
logging.warning(
|
||||
"%s\n%s",
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if cfg.load_4bit:
|
||||
raise ValueError(
|
||||
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
||||
@@ -38,9 +52,59 @@ def validate_config(cfg):
|
||||
)
|
||||
|
||||
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
||||
raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
|
||||
raise ValueError(
|
||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||
)
|
||||
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if (
|
||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
logging.warning(
|
||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||
)
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||
logging.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".")[0]) < 2:
|
||||
logging.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
logging.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
|
||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
if cfg.push_to_hub_model_id:
|
||||
raise ValueError(
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adamw w bf16
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Module for wandb utilities"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@@ -13,3 +15,5 @@ def setup_wandb_env_vars(cfg):
|
||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||
else:
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
[
|
||||
{
|
||||
"instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
|
||||
"input": "Words: ['Hello', 'world'].",
|
||||
"output": "['world', 'Hello']"
|
||||
},
|
||||
{
|
||||
"instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
|
||||
"input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
|
||||
"output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
|
||||
}
|
||||
]
|
||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -1,3 +1,6 @@
|
||||
"""Module for testing DictDefault class"""
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -6,6 +9,10 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class DictDefaultTest(unittest.TestCase):
|
||||
"""
|
||||
Test DictDefault class
|
||||
"""
|
||||
|
||||
def test_dict_default(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -41,7 +48,9 @@ class DictDefaultTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = cfg | DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
|
||||
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
|
||||
)
|
||||
|
||||
assert (
|
||||
cfg.key_a.key_b == "value_b"
|
||||
@@ -73,7 +82,7 @@ class DictDefaultTest(unittest.TestCase):
|
||||
AttributeError,
|
||||
match=r"'NoneType' object has no attribute 'another_random_key'",
|
||||
):
|
||||
cfg.random_key.another_random_key
|
||||
cfg.random_key.another_random_key = "value"
|
||||
|
||||
def test_dict_shorthand_assignment(self):
|
||||
"""
|
||||
|
||||
65
tests/test_packed_dataset.py
Normal file
65
tests/test_packed_dataset.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Module for testing dataset sequence packing"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
|
||||
|
||||
class TestPacking(unittest.TestCase):
|
||||
"""
|
||||
Test class for packing dataset sequences
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_resets_attention(self):
|
||||
prompter = AlpacaPrompter("chat")
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
dateset = load_dataset(
|
||||
"json",
|
||||
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
|
||||
)["train"]
|
||||
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
|
||||
|
||||
constant_len_dataset = ConstantLengthDataset(
|
||||
self.tokenizer,
|
||||
[dataset],
|
||||
seq_length=2048,
|
||||
)
|
||||
packed_dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
example = packed_dataset[0]
|
||||
next_bos_index = (
|
||||
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
|
||||
) # add one since we sliced
|
||||
|
||||
# first example doesn't have mask reset
|
||||
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][0] == 1
|
||||
|
||||
# but subsequent one does
|
||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][next_bos_index] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,3 +1,4 @@
|
||||
"""Module for testing prompt tokenizers."""
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
@@ -5,14 +6,27 @@ from pathlib import Path
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompter
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
SystemDataPrompter,
|
||||
)
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
||||
|
||||
logging.basicConfig(level="INFO")
|
||||
|
||||
|
||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
"""
|
||||
Test class for prompt tokenization strategies.
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
@@ -23,11 +37,15 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_sharegpt_integration(self):
|
||||
print(Path(__file__).parent)
|
||||
with open(Path(__file__).parent / "fixtures/conversation.json", "r") as fin:
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
conversation = json.loads(data)
|
||||
with open(Path(__file__).parent / "fixtures/conversation.tokenized.json", "r") as fin:
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.tokenized.json",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
tokenized_conversation = json.loads(data)
|
||||
prompter = ShareGPTPrompter("chat")
|
||||
@@ -42,6 +60,79 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
prompter = NoSystemPrompter()
|
||||
# pylint: disable=duplicate-code
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
sample = {
|
||||
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
|
||||
"output": "world!",
|
||||
}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
world_idx = example["input_ids"].index(3186)
|
||||
assert example["labels"][world_idx] == 3186
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
def test_alpaca(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
prompter = AlpacaPrompter()
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
world_idx = example["input_ids"].index(6324)
|
||||
assert example["labels"][world_idx] == 6324
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
|
||||
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_system_alpaca(self):
|
||||
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
|
||||
strat = InstructionWSystemPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
sample = {
|
||||
"system": "use cot",
|
||||
"instruction": "hello!",
|
||||
"output": "Hi! How can I help?",
|
||||
}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
|
||||
assert example["input_ids"][3] == 11889 # USER
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,9 +1,21 @@
|
||||
"""Module testing prompters"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
PromptStyle,
|
||||
UnpromptedPrompter,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaPrompterTest(unittest.TestCase):
|
||||
"""
|
||||
Test AlpacaPrompter
|
||||
"""
|
||||
|
||||
def test_prompt_style_w_none(self):
|
||||
prompter = AlpacaPrompter(prompt_style=None)
|
||||
res = next(prompter.build_prompt("tell me a joke"))
|
||||
@@ -11,8 +23,10 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
assert "### Instruction:" in res
|
||||
|
||||
def test_prompt_style_w_instruct(self):
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
|
||||
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "Below is an instruction" in res
|
||||
assert "### Instruction:" in res
|
||||
assert "### Input:" in res
|
||||
@@ -29,8 +43,10 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
assert "ASSISTANT:" not in res
|
||||
|
||||
def test_prompt_style_w_chat(self):
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
|
||||
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "Below is an instruction" in res
|
||||
assert "### Instruction:" not in res
|
||||
assert "### Input:" not in res
|
||||
@@ -46,4 +62,63 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
assert "USER:" in res
|
||||
assert "ASSISTANT:" in res
|
||||
|
||||
def test_system_prompt(self):
|
||||
prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(
|
||||
prompter.build_prompt_w_system(
|
||||
"use cot", "tell me a joke about the following", "alpacas"
|
||||
)
|
||||
)
|
||||
assert "use cot" in res
|
||||
assert res.startswith("use cot")
|
||||
assert "### Instruction:" not in res
|
||||
assert "### Input:" not in res
|
||||
assert "alpacas" in res
|
||||
assert "### Response:" not in res
|
||||
assert "USER:" in res
|
||||
assert "ASSISTANT:" in res
|
||||
|
||||
|
||||
class UnpromptedPrompterTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for UnpromptedPrompter with no system prompts
|
||||
"""
|
||||
|
||||
def test_prompt_style_w_none(self):
|
||||
prompter = UnpromptedPrompter(prompt_style=None)
|
||||
res = next(prompter.build_prompt("tell me a joke"))
|
||||
assert "### Instruction:" in res
|
||||
assert "tell me a joke" in res
|
||||
assert res.startswith("###")
|
||||
|
||||
def test_prompt_style_w_instruct(self):
|
||||
prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "### Instruction:" in res
|
||||
assert "tell me a joke" in res
|
||||
assert res.startswith("###")
|
||||
|
||||
def test_prompt_style_w_chat(self):
|
||||
prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "USER:" in res
|
||||
assert "tell me a joke" in res
|
||||
assert res.startswith("USER:")
|
||||
|
||||
|
||||
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for MultipleChoiceExplainPrompter
|
||||
"""
|
||||
|
||||
def test_prompt_style_w_chat(self):
|
||||
prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
|
||||
assert "USER:" in res
|
||||
assert "choose one" in res
|
||||
assert "Choose the answer that best answers the question." in res
|
||||
assert "- A\n- B\n- C" in res
|
||||
|
||||
31
tests/test_tokenizers.py
Normal file
31
tests/test_tokenizers.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Test cases for the tokenizer loading
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
|
||||
|
||||
class TestTokenizers(unittest.TestCase):
|
||||
"""
|
||||
test class for the load_tokenizer fn
|
||||
"""
|
||||
|
||||
def test_default_use_fast(self):
|
||||
cfg = DictDefault({})
|
||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||
assert "Fast" in tokenizer.__class__.__name__
|
||||
|
||||
def test_dont_use_fast(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_use_fast": False,
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
|
||||
assert "Fast" not in tokenizer.__class__.__name__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,12 +1,26 @@
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.validation import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.validation import validate_config
|
||||
|
||||
|
||||
class ValidationTest(unittest.TestCase):
|
||||
"""
|
||||
Test the validation module
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
def test_load_4bit_deprecate(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -17,6 +31,17 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_batch_size_unused_warning(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"batch_size": 32,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert "batch_size is not recommended" in self._caplog.records[0].message
|
||||
|
||||
def test_qlora(self):
|
||||
base_cfg = DictDefault(
|
||||
{
|
||||
@@ -24,7 +49,7 @@ class ValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -33,7 +58,7 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError, match=r".*8bit.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
@@ -42,7 +67,7 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError, match=r".*gptq.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": False,
|
||||
}
|
||||
@@ -51,7 +76,7 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
@@ -67,7 +92,7 @@ class ValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -76,7 +101,7 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError, match=r".*8bit.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
@@ -85,7 +110,7 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError, match=r".*gptq.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
@@ -112,3 +137,179 @@ class ValidationTest(unittest.TestCase):
|
||||
)
|
||||
validate_config(cfg)
|
||||
|
||||
def test_gradient_accumulations_or_batch_size(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_falcon_fsdp(self):
|
||||
regex_exp = r".*FSDP is not supported for falcon models.*"
|
||||
|
||||
# Check for lower-case
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "tiiuae/falcon-7b",
|
||||
"fsdp": ["full_shard", "auto_wrap"],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
# Check for upper-case
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Falcon-7b",
|
||||
"fsdp": ["full_shard", "auto_wrap"],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "tiiuae/falcon-7b",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_mpt_gradient_checkpointing(self):
|
||||
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
|
||||
|
||||
# Check for lower-case
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "mosaicml/mpt-7b",
|
||||
"gradient_checkpointing": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_flash_optimum(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
"adapter": "lora",
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"probably set bfloat16 or float16" in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
regex_exp = r".*AMP is not supported.*"
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
regex_exp = r".*AMP is not supported.*"
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_adamw_hyperparams(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": None,
|
||||
"adam_epsilon": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": "adafactor",
|
||||
"adam_beta1": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"adam_beta1": 0.9,
|
||||
"adam_beta2": 0.99,
|
||||
"adam_epsilon": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": "adafactor",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user