Compare commits
437 Commits
v0.1.0
...
feature/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1afbd8af2d | ||
|
|
b4f2eea2ed | ||
|
|
bbf88b02c1 | ||
|
|
64a8e04430 | ||
|
|
c8f7213bc6 | ||
|
|
b57238ecec | ||
|
|
918f1b0dfb | ||
|
|
c3fde36ada | ||
|
|
2bb0b78975 | ||
|
|
a276c9c88d | ||
|
|
7019509daa | ||
|
|
96bd6ae1c4 | ||
|
|
e37d9358e6 | ||
|
|
b5212068ac | ||
|
|
289d5c403d | ||
|
|
35c8b90306 | ||
|
|
fae6ed8092 | ||
|
|
94d03c8402 | ||
|
|
11ddccb80f | ||
|
|
964312199e | ||
|
|
718102271f | ||
|
|
f5c11f8262 | ||
|
|
fce40aab23 | ||
|
|
9c314101d5 | ||
|
|
e303d64728 | ||
|
|
b4d1d22782 | ||
|
|
9f99104038 | ||
|
|
36fefcf94b | ||
|
|
176b888a63 | ||
|
|
3392270544 | ||
|
|
bb53a165f5 | ||
|
|
10405b9995 | ||
|
|
c93655c0a3 | ||
|
|
fe285430bc | ||
|
|
0d2e34f056 | ||
|
|
b56a6c0101 | ||
|
|
2eda9e02a9 | ||
|
|
78b9efb7f4 | ||
|
|
312a9fad07 | ||
|
|
58d665943e | ||
|
|
cc7e80026e | ||
|
|
dc71d8872a | ||
|
|
248bf90f89 | ||
|
|
77085ea24e | ||
|
|
db2a3586f3 | ||
|
|
6c9a87c8ee | ||
|
|
894cba09f3 | ||
|
|
41a4d15d43 | ||
|
|
2c37bf6c21 | ||
|
|
9f69c4d8c1 | ||
|
|
3d4984b9a5 | ||
|
|
ff7f18d1ed | ||
|
|
cf62cfd661 | ||
|
|
c5df969262 | ||
|
|
40a53ff181 | ||
|
|
dcdec44347 | ||
|
|
3ffb018a4c | ||
|
|
a94f2eecb1 | ||
|
|
1066751358 | ||
|
|
1b63bf13bc | ||
|
|
5cce2a42ff | ||
|
|
2a428e8014 | ||
|
|
cdf85fdbd5 | ||
|
|
9b790d359b | ||
|
|
38811434e6 | ||
|
|
06c61d6f13 | ||
|
|
262dc29df2 | ||
|
|
165907fddb | ||
|
|
a032c9f452 | ||
|
|
b06d3e3645 | ||
|
|
c58034d48c | ||
|
|
28fd429bcf | ||
|
|
45ac7c4f88 | ||
|
|
edd6980dd9 | ||
|
|
dc6d25124d | ||
|
|
6dd2e7d671 | ||
|
|
b64f411849 | ||
|
|
03a59c1ed4 | ||
|
|
ebaec3c406 | ||
|
|
73e70e3996 | ||
|
|
d75adb9835 | ||
|
|
02224668c3 | ||
|
|
f162f3c7cc | ||
|
|
eca3531329 | ||
|
|
6f16c4569d | ||
|
|
0bd09c077d | ||
|
|
469c08c9ba | ||
|
|
334af625d0 | ||
|
|
273b3a3aa7 | ||
|
|
3cdd8e4122 | ||
|
|
cf5ae6b649 | ||
|
|
b1f4f7a34d | ||
|
|
83237b8445 | ||
|
|
46032a1a1f | ||
|
|
8bba64258e | ||
|
|
88089e8b32 | ||
|
|
168a7a09cc | ||
|
|
231031a0e1 | ||
|
|
9234b75cb4 | ||
|
|
553a86b52c | ||
|
|
5daf7d5299 | ||
|
|
5491278a79 | ||
|
|
1514739f0f | ||
|
|
896c1aebcf | ||
|
|
ef17e15483 | ||
|
|
69a235061b | ||
|
|
687d889928 | ||
|
|
c4cf567b55 | ||
|
|
c49729d2bc | ||
|
|
13ac4d8de2 | ||
|
|
19cf0bda99 | ||
|
|
f74edd5b56 | ||
|
|
d69da99c2c | ||
|
|
66afb76a15 | ||
|
|
a692ad3f4c | ||
|
|
41da98b982 | ||
|
|
9e64f42e0f | ||
|
|
b9b7d4ce92 | ||
|
|
9bed281867 | ||
|
|
e79c8e617e | ||
|
|
71456955f5 | ||
|
|
3a783c04e4 | ||
|
|
1e5014acec | ||
|
|
a10da1caff | ||
|
|
4066c78631 | ||
|
|
78a1e1fa12 | ||
|
|
bc8a2e5547 | ||
|
|
910ebe47f5 | ||
|
|
c146880a75 | ||
|
|
77bdb7d144 | ||
|
|
530809fd74 | ||
|
|
924bbfddec | ||
|
|
f150c027e3 | ||
|
|
5c39c006c9 | ||
|
|
612aabd8c4 | ||
|
|
af05883f75 | ||
|
|
05ab9092e3 | ||
|
|
7b57ed7618 | ||
|
|
3a38271276 | ||
|
|
8d20e0a3d3 | ||
|
|
de8ed229c3 | ||
|
|
478d8c7b8e | ||
|
|
645c13592c | ||
|
|
47d601fa23 | ||
|
|
756dfba97b | ||
|
|
91ab0592af | ||
|
|
0aeb7c7802 | ||
|
|
9bdd30cdfd | ||
|
|
d35278aaf1 | ||
|
|
9492d4ebb7 | ||
|
|
ad5ca4f734 | ||
|
|
cb9d3af5c0 | ||
|
|
c969f0a9dc | ||
|
|
6d0ee4ba34 | ||
|
|
a81f52d575 | ||
|
|
1925eaf1e6 | ||
|
|
1ab3bf3e67 | ||
|
|
d7635b7148 | ||
|
|
88e17ffc50 | ||
|
|
baed440fa1 | ||
|
|
7925ddce86 | ||
|
|
6f849809c5 | ||
|
|
c16644d05e | ||
|
|
945c4191a3 | ||
|
|
136522f9c9 | ||
|
|
556fe408b3 | ||
|
|
16bb6276a5 | ||
|
|
06674a11f2 | ||
|
|
3513885f43 | ||
|
|
06652c1c39 | ||
|
|
068fc48978 | ||
|
|
aaadacf6b3 | ||
|
|
5ff547dc70 | ||
|
|
dc77c8ebce | ||
|
|
51a4c12242 | ||
|
|
4b43a66a0b | ||
|
|
34ae69989f | ||
|
|
7dc580b837 | ||
|
|
fd2c9814c9 | ||
|
|
2ba4ae8f46 | ||
|
|
93dacba228 | ||
|
|
8002ffb41f | ||
|
|
74ef5cc083 | ||
|
|
5e616d91c0 | ||
|
|
94f310c7a6 | ||
|
|
8e568bbdae | ||
|
|
e21dab49fd | ||
|
|
52cde69288 | ||
|
|
9a58e99e81 | ||
|
|
c7dee56b87 | ||
|
|
aac4b7691e | ||
|
|
f31a338cbb | ||
|
|
4cd1deeef2 | ||
|
|
9ac16ed8d1 | ||
|
|
6b3f509d9e | ||
|
|
336aa3fd48 | ||
|
|
d0d7eaa4f3 | ||
|
|
a6ebf57e82 | ||
|
|
280832cec2 | ||
|
|
a43bae9ff0 | ||
|
|
effbbf6dd1 | ||
|
|
c9a149f9e8 | ||
|
|
c530e4b9c8 | ||
|
|
f620706776 | ||
|
|
77762a5d6b | ||
|
|
14668fa54e | ||
|
|
b565ecf0a1 | ||
|
|
fe0b76854e | ||
|
|
e944311442 | ||
|
|
e3e7b52a5b | ||
|
|
974dc00a7d | ||
|
|
572d1141e6 | ||
|
|
a6190c8094 | ||
|
|
563b6d89e6 | ||
|
|
cd0a6f6027 | ||
|
|
0e664a5ebc | ||
|
|
dd7d16d2eb | ||
|
|
e285e24f7f | ||
|
|
919727b4d7 | ||
|
|
5ffefee37f | ||
|
|
d9f713e4e3 | ||
|
|
958da70376 | ||
|
|
c4e4f8115c | ||
|
|
a808bf913f | ||
|
|
01248253a3 | ||
|
|
759e8673ce | ||
|
|
0c6f928601 | ||
|
|
eea2731a5e | ||
|
|
1db46a9c72 | ||
|
|
ab5cd28acf | ||
|
|
1a82082e91 | ||
|
|
1210dc8fd5 | ||
|
|
488a67d75a | ||
|
|
71a43f8479 | ||
|
|
39619028a3 | ||
|
|
8792199799 | ||
|
|
1edc30c786 | ||
|
|
14163c15d9 | ||
|
|
41e4f6ca31 | ||
|
|
79e2a6f140 | ||
|
|
c2508987a6 | ||
|
|
215d775147 | ||
|
|
f36e227eaf | ||
|
|
5878bb1f3a | ||
|
|
a03a7d7d8b | ||
|
|
fec6bcc3e6 | ||
|
|
931e606459 | ||
|
|
7f09106437 | ||
|
|
6b50200234 | ||
|
|
16f9e28048 | ||
|
|
b9083a7fc1 | ||
|
|
aefb2fc681 | ||
|
|
b5aa8d854c | ||
|
|
4d6490bce2 | ||
|
|
b242b69e10 | ||
|
|
320beb20f4 | ||
|
|
bd3b537344 | ||
|
|
813cfa4c14 | ||
|
|
2e13ceff37 | ||
|
|
2a801b001a | ||
|
|
e44c9e0b3e | ||
|
|
55b8542de8 | ||
|
|
febe902517 | ||
|
|
f4df266842 | ||
|
|
281dc3df59 | ||
|
|
2ef4634d45 | ||
|
|
7eae90333e | ||
|
|
c8242de725 | ||
|
|
2cfe9e9b16 | ||
|
|
79a8f52181 | ||
|
|
afaa0d2c01 | ||
|
|
bfd27ba55e | ||
|
|
babf0fdb71 | ||
|
|
a52f4816b0 | ||
|
|
81911d112c | ||
|
|
52765ac588 | ||
|
|
73e9ea4069 | ||
|
|
f8d379883d | ||
|
|
04a1b77307 | ||
|
|
2097a09d2d | ||
|
|
cfff94b123 | ||
|
|
2b222de5b6 | ||
|
|
df9528f865 | ||
|
|
193c73bce0 | ||
|
|
6abfd87d44 | ||
|
|
59bb2197ed | ||
|
|
9a02e7e1ff | ||
|
|
5b33e295bd | ||
|
|
4ac9e251b7 | ||
|
|
c9c050316f | ||
|
|
ca11ae9689 | ||
|
|
328c3bce96 | ||
|
|
5cd2126439 | ||
|
|
12620f3089 | ||
|
|
4ab0c8b201 | ||
|
|
74ebbf4371 | ||
|
|
76a70fd739 | ||
|
|
618816d4df | ||
|
|
91992cb8f5 | ||
|
|
84169d15b3 | ||
|
|
ecfe8d0a1a | ||
|
|
eee44a3b47 | ||
|
|
078a43eef8 | ||
|
|
33e1890086 | ||
|
|
1c38253692 | ||
|
|
496b83f778 | ||
|
|
ff68a95781 | ||
|
|
fb3d40f197 | ||
|
|
288fd62431 | ||
|
|
3c71c8debe | ||
|
|
a6f5e5eaec | ||
|
|
5a631b305b | ||
|
|
f94dd626f0 | ||
|
|
5079753b7a | ||
|
|
0136f510f2 | ||
|
|
72bf8aafb6 | ||
|
|
8afb0fbaba | ||
|
|
9b8585dc70 | ||
|
|
8eb5811d4e | ||
|
|
e0011fdf55 | ||
|
|
6e9e98720e | ||
|
|
c2a0792680 | ||
|
|
b267d24a2b | ||
|
|
5c3f5db38b | ||
|
|
e3d03745ba | ||
|
|
fac46002d4 | ||
|
|
33d40179ba | ||
|
|
dcb03d6da4 | ||
|
|
0e4be625ae | ||
|
|
bdc4bd7d4e | ||
|
|
2d0ba3b818 | ||
|
|
c7021e191f | ||
|
|
c56818b119 | ||
|
|
2675fb756e | ||
|
|
1076bcbbca | ||
|
|
2daa6835f0 | ||
|
|
e3c494ca7b | ||
|
|
ad0ea6aaab | ||
|
|
876edd83d0 | ||
|
|
6cb2310592 | ||
|
|
6fa40bf8ad | ||
|
|
3aad5f3b3e | ||
|
|
39a208c2bc | ||
|
|
2520ecd6df | ||
|
|
c5b0af1a7e | ||
|
|
988aeb9c34 | ||
|
|
cf61f14bff | ||
|
|
0abcd71a85 | ||
|
|
c43c5c84ff | ||
|
|
36ec6e1a0e | ||
|
|
13b80937f9 | ||
|
|
bbc5bc5791 | ||
|
|
4df9da74e3 | ||
|
|
2531ea24c1 | ||
|
|
01a75fd027 | ||
|
|
b81c97ff76 | ||
|
|
594e72b6e8 | ||
|
|
25eeeeba0b | ||
|
|
cfcc549f6b | ||
|
|
a1f9850b91 | ||
|
|
83d29209f7 | ||
|
|
d011422200 | ||
|
|
b1cc54b14a | ||
|
|
c17dae6d07 | ||
|
|
37293dce07 | ||
|
|
96e8378692 | ||
|
|
e9650d3ae4 | ||
|
|
f1232b35ba | ||
|
|
741a3f2edc | ||
|
|
0dd35c74af | ||
|
|
db288e9b13 | ||
|
|
be22551435 | ||
|
|
b832a0ac62 | ||
|
|
afb31e13a3 | ||
|
|
1bf1f59a41 | ||
|
|
8e46c0fb0d | ||
|
|
1f3c3f5ea0 | ||
|
|
0e952889dc | ||
|
|
9c6750a075 | ||
|
|
c2dbf2c526 | ||
|
|
e6b57decbd | ||
|
|
fe1f4c4e7d | ||
|
|
dae14e5951 | ||
|
|
633ff2150f | ||
|
|
5d86137f70 | ||
|
|
01c8a333b3 | ||
|
|
7eb33a77dd | ||
|
|
1645a4ddd5 | ||
|
|
145b060cbe | ||
|
|
8cc0aadcb8 | ||
|
|
6abb7f6a16 | ||
|
|
de2406c488 | ||
|
|
8b617cc7f6 | ||
|
|
ddb86ea821 | ||
|
|
1a2bd7ff62 | ||
|
|
82971e1565 | ||
|
|
f4e5d86268 | ||
|
|
daf47ccf45 | ||
|
|
545cfeb5c7 | ||
|
|
69722aeef4 | ||
|
|
5658717dbd | ||
|
|
e8717d3bef | ||
|
|
54c3b5b25f | ||
|
|
5062eca069 | ||
|
|
cb4f0e9342 | ||
|
|
4c0eddb3f8 | ||
|
|
1c60c10e00 | ||
|
|
903ea3080d | ||
|
|
cb7cd3429f | ||
|
|
d57ba56746 | ||
|
|
c3a4697016 | ||
|
|
392dfd9b07 | ||
|
|
a98deb31a6 | ||
|
|
36596adaf7 | ||
|
|
6cee881d64 | ||
|
|
48612f8376 | ||
|
|
d91a769b88 | ||
|
|
6ef96f569b | ||
|
|
ac85c0ed36 | ||
|
|
f1fbf666f7 | ||
|
|
370d057096 | ||
|
|
e0ccaccce2 | ||
|
|
15e57ba6ee | ||
|
|
4eb68ac3f7 | ||
|
|
b6a539b53c | ||
|
|
abddcf4dfe | ||
|
|
15aabd2903 | ||
|
|
232b931081 | ||
|
|
0736f4f9c1 | ||
|
|
d77d736631 | ||
|
|
fad06befee | ||
|
|
2aacf75ee1 | ||
|
|
71871345a6 | ||
|
|
0d14e951a8 | ||
|
|
84fc217f79 | ||
|
|
f317296259 | ||
|
|
42a971df32 |
5
.flake8
Normal file
5
.flake8
Normal file
@@ -0,0 +1,5 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
|
||||
select = C,E,F,W,B,B950
|
||||
extend-ignore = E203, E501, W503
|
||||
31
.github/release-drafter.yml
vendored
Normal file
31
.github/release-drafter.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
name-template: 'v$RESOLVED_VERSION'
|
||||
tag-template: 'v$RESOLVED_VERSION'
|
||||
categories:
|
||||
- title: '🚀 Features'
|
||||
labels:
|
||||
- 'feature'
|
||||
- 'enhancement'
|
||||
- title: '🐛 Bug Fixes'
|
||||
labels:
|
||||
- 'fix'
|
||||
- 'bugfix'
|
||||
- 'bug'
|
||||
- title: '🧰 Maintenance'
|
||||
label: 'chore'
|
||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
|
||||
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.
|
||||
version-resolver:
|
||||
major:
|
||||
labels:
|
||||
- 'major'
|
||||
minor:
|
||||
labels:
|
||||
- 'minor'
|
||||
patch:
|
||||
labels:
|
||||
- 'patch'
|
||||
default: patch
|
||||
template: |
|
||||
## What’s Changed
|
||||
|
||||
$CHANGES
|
||||
24
.github/workflows/base.yml
vendored
24
.github/workflows/base.yml
vendored
@@ -12,16 +12,19 @@ jobs:
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: self-hosted
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: cu118
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
cuda_version_bnb: "118"
|
||||
pytorch: 2.0.0
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
cuda_version_bnb: "117"
|
||||
pytorch: 1.13.1
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -43,12 +46,11 @@ jobs:
|
||||
context: .
|
||||
file: ./docker/Dockerfile-base
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
CUDA_VERSION_BNB=${{ matrix.cuda_version_bnb }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTHON_VERSION=${{ matrix.python_version }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||
|
||||
50
.github/workflows/main.yml
vendored
50
.github/workflows/main.yml
vendored
@@ -11,14 +11,24 @@ jobs:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
pytorch: 2.0.0
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
pytorch: 1.13.1
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras: gptq
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -40,13 +50,11 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
@@ -54,12 +62,21 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: cu118
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
pytorch: 2.0.0
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
pytorch: 1.13.1
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras: gptq
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -81,10 +98,9 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-runpod
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
16
.github/workflows/pre-commit.yml
vendored
Normal file
16
.github/workflows/pre-commit.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -7,6 +7,7 @@ jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.9", "3.10"]
|
||||
timeout-minutes: 10
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -160,4 +160,4 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
.idea/
|
||||
|
||||
2
.isort.cfg
Normal file
2
.isort.cfg
Normal file
@@ -0,0 +1,2 @@
|
||||
[settings]
|
||||
profile=black
|
||||
39
.mypy.ini
Normal file
39
.mypy.ini
Normal file
@@ -0,0 +1,39 @@
|
||||
[mypy]
|
||||
|
||||
exclude = venv
|
||||
|
||||
[mypy-alpaca_lora_4bit.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-flash_attn.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-huggingface_hub]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-transformers.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-peft]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-bitsandbytes]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-datasets]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-fire]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-setuptools]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-addict]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-xformers.*]
|
||||
ignore_missing_imports = True
|
||||
42
.pre-commit-config.yaml
Normal file
42
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
default_language_version:
|
||||
python: python3
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/pylint
|
||||
rev: v2.17.4
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.3.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
[
|
||||
'types-PyYAML',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.7.5
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
'--ini',
|
||||
'.bandit',
|
||||
]
|
||||
14
.pylintrc
Normal file
14
.pylintrc
Normal file
@@ -0,0 +1,14 @@
|
||||
[MASTER]
|
||||
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=numpy.*, torch.*
|
||||
|
||||
|
||||
[pylint.messages_control]
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
3
FAQS.md
3
FAQS.md
@@ -2,3 +2,6 @@
|
||||
|
||||
- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
|
||||
- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
|
||||
- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
|
||||
`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
|
||||
This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
|
||||
|
||||
202
LICENSE
Normal file
202
LICENSE
Normal file
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
358
README.md
358
README.md
@@ -9,36 +9,39 @@
|
||||
<p>
|
||||
Go ahead and axolotl questions!!
|
||||
</p>
|
||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Axolotl supports
|
||||
|
||||
| | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
|
||||
|---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❓ |
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
|
||||
|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ❓ | ✅
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
|
||||
**Requirements**: Python 3.9.
|
||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
|
||||
pip3 install -e .[int4]
|
||||
|
||||
accelerate config
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
# finetune lora
|
||||
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
--inference --lora_model_dir="./lora-out"
|
||||
```
|
||||
|
||||
@@ -48,18 +51,82 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
||||
|
||||
- Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod
|
||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
```sh
|
||||
docker compose up -d
|
||||
```
|
||||
- `winglian/axolotl:dev`: dev branch
|
||||
- `winglian/axolotl-runpod:main`: for runpod
|
||||
|
||||
- Conda/Pip venv
|
||||
1. Install python **3.9**
|
||||
|
||||
2. Install python dependencies with ONE of the following:
|
||||
- `pip3 install -e .[int4]` (recommended)
|
||||
- `pip3 install -e .[int4_triton]`
|
||||
- `pip3 install -e .`
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install python dependencies with ONE of the following:
|
||||
- Recommended, supports QLoRA, NO gptq/int4 support
|
||||
```bash
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
- gptq/int4 support, NO QLoRA
|
||||
```bash
|
||||
pip3 install -e .[gptq]
|
||||
```
|
||||
- same as above but not recommended
|
||||
```bash
|
||||
pip3 install -e .[gptq_triton]
|
||||
```
|
||||
|
||||
- LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
|
||||
1. Install python
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install -y python3.9
|
||||
|
||||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
|
||||
sudo update-alternatives --config python # pick 3.9 if given option
|
||||
python -V # should be 3.9
|
||||
|
||||
```
|
||||
|
||||
2. Install pip
|
||||
```bash
|
||||
wget https://bootstrap.pypa.io/get-pip.py
|
||||
python get-pip.py
|
||||
```
|
||||
|
||||
3. Install torch
|
||||
```bash
|
||||
pip3 install -U torch --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
4. Axolotl
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install -e . # change depend on needs
|
||||
pip3 install protobuf==3.20.3
|
||||
pip3 install -U requests
|
||||
pip3 install -U --ignore-installed psutil
|
||||
pip3 install -U scipy
|
||||
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
|
||||
```
|
||||
|
||||
5. Set path
|
||||
```bash
|
||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
|
||||
```
|
||||
</details>
|
||||
|
||||
### Dataset
|
||||
|
||||
@@ -69,7 +136,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations
|
||||
- `sharegpt:chat`: conversations
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -110,16 +177,73 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"article": "...", "summary": "..."}
|
||||
```
|
||||
|
||||
> Have some new format to propose? Check if it's already defined in [data.py](src/axolotl/utils/data.py) in `dev` branch!
|
||||
- `alpaca_chat`: basic instruct for alpaca chat
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "response": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_qa`: question and answer for alpaca chat
|
||||
```json
|
||||
{"question": "...", "answer": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_concise`: question and answer for alpaca chat, for concise answers
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "response": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_camel_ai`: question and answer for alpaca chat, for load_camel_ai
|
||||
```json
|
||||
{"message_1": "...", "message_2": "..."}
|
||||
```
|
||||
- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
|
||||
```json
|
||||
{"system_prompt": "...", "question": "...", "response": "..."}
|
||||
```
|
||||
- `context_qa`: in context question answering from an article
|
||||
```json
|
||||
{"article": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||
```json
|
||||
{"article": "...", "unanswerable_question": "..."}
|
||||
```
|
||||
- `creative_acr.load_answer`: instruction and revision
|
||||
```json
|
||||
{"instruction": "...", "revision": "..."}
|
||||
```
|
||||
- `creative_acr.load_critique`: critique
|
||||
```json
|
||||
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "..."}
|
||||
```
|
||||
- `creative_acr.load_revise`: critique and revise
|
||||
```json
|
||||
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
|
||||
```
|
||||
- `pygmalion`: pygmalion
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||
```json
|
||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
Optionally, download some datasets, see [data/README.md](data/README.md)
|
||||
|
||||
|
||||
|
||||
### Config
|
||||
|
||||
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
|
||||
See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
|
||||
|
||||
- model
|
||||
```yaml
|
||||
@@ -129,10 +253,24 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
|
||||
|
||||
- dataset
|
||||
```yaml
|
||||
sequence_len: 2048 # max token length for prompt
|
||||
|
||||
# huggingface repo
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4 # local or huggingface repo
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca # format from earlier
|
||||
|
||||
# huggingface repo with specific configuration/subset
|
||||
datasets:
|
||||
- path: EleutherAI/pile
|
||||
name: enron_emails
|
||||
type: completion # format from earlier
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: json
|
||||
data_files: data.jsonl # or json
|
||||
type: alpaca # format from earlier
|
||||
sequence_len: 2048 # max token length / prompt
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -142,6 +280,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
|
||||
bf16: true # require >=ampere
|
||||
fp16: true
|
||||
tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
|
||||
float16: true # use instead of fp16 when you don't want AMP
|
||||
```
|
||||
Note: Repo does not do 4-bit quantization.
|
||||
|
||||
@@ -169,12 +309,22 @@ base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# you can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Trust remote code for untrusted source
|
||||
trust_remote_code:
|
||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||
tokenizer_use_fast:
|
||||
# resize the model embeddings when new tokens are added to multiples of 32
|
||||
# this is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
# whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
@@ -195,18 +345,21 @@ tf32: true # require >=ampere
|
||||
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# this can be either a hf dataset, or relative path
|
||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format OR format:prompt_style (chat/instruct)
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
data_files: # path to source data files
|
||||
shards: # number of shards to split data into
|
||||
name: # name of dataset configuration to load
|
||||
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # repo path to push finetuned model
|
||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
@@ -222,7 +375,14 @@ dataset_shard_idx:
|
||||
sequence_len: 2048
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
@@ -248,30 +408,39 @@ lora_out_dir:
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode:
|
||||
wandb_project:
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: # 'checkpoint'
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# where to save the finished model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
eval_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
logging_steps:
|
||||
save_steps:
|
||||
eval_steps:
|
||||
save_total_limit:
|
||||
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
group_by_length: false
|
||||
|
||||
# does not work with current implementation of 4-bit LoRA
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
@@ -293,11 +462,31 @@ log_sweep_max_lr:
|
||||
optimizer:
|
||||
# specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# whether to bettertransformers
|
||||
flash_optimum:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
||||
flash_attention: # require a100 for llama
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# llama only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -329,6 +518,9 @@ torchdistx_path:
|
||||
# Set padding for data collator to 'longest'
|
||||
collator_pad_to_longest:
|
||||
|
||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||
pretraining_dataset:
|
||||
|
||||
# Debug mode
|
||||
debug:
|
||||
|
||||
@@ -341,17 +533,6 @@ strict:
|
||||
|
||||
</details>
|
||||
|
||||
### Accelerate
|
||||
|
||||
Configure accelerate
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
|
||||
# Edit manually
|
||||
# nano ~/.cache/huggingface/accelerate/default_config.yaml
|
||||
```
|
||||
|
||||
### Train
|
||||
|
||||
Run
|
||||
@@ -359,17 +540,56 @@ Run
|
||||
accelerate launch scripts/finetune.py configs/your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
|
||||
```
|
||||
|
||||
##### Config
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_offload_params: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
- llama Deepspeed: append `ACCELERATE_USE_DEEPSPEED=true` in front of finetune command
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
- wandb options
|
||||
```yaml
|
||||
wandb_mode:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
|
||||
- Pretrained LORA:
|
||||
```bash
|
||||
--inference --lora_model_dir ./completed-model
|
||||
--inference --lora_model_dir="./lora-output-dir"
|
||||
```
|
||||
- Full weights finetune:
|
||||
```bash
|
||||
--inference --base_model ./completed-model
|
||||
--inference --base_model="./completed-model"
|
||||
```
|
||||
- Full weights finetune w/ a prompt from a text file:
|
||||
```bash
|
||||
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
|
||||
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
|
||||
```
|
||||
|
||||
### Merge LORA to base
|
||||
@@ -380,6 +600,12 @@ Add below flag to train command above
|
||||
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
```
|
||||
|
||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
|
||||
```
|
||||
|
||||
## Common Errors 🧰
|
||||
|
||||
> Cuda out of memory
|
||||
@@ -387,6 +613,7 @@ Add below flag to train command above
|
||||
Please reduce any below
|
||||
- `micro_batch_size`
|
||||
- `eval_batch_size`
|
||||
- `gradient_accumulation_steps`
|
||||
- `sequence_len`
|
||||
|
||||
> RuntimeError: expected scalar type Float but found Half
|
||||
@@ -397,12 +624,45 @@ Try set `fp16: true`
|
||||
|
||||
Try to turn off xformers.
|
||||
|
||||
## Need help? 🙋♂️
|
||||
> accelerate config missing
|
||||
|
||||
It's safe to ignore it.
|
||||
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
## Badge ❤🏷️
|
||||
|
||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||
|
||||
```markdown
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
```
|
||||
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
|
||||
## Community Showcase
|
||||
|
||||
Open Access AI Collective
|
||||
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
|
||||
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
|
||||
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
|
||||
|
||||
PocketDoc Labs
|
||||
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
|
||||
|
||||
## Contributing 🤝
|
||||
|
||||
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
|
||||
|
||||
PRs are **greatly welcome**!
|
||||
|
||||
Please run below to setup env
|
||||
```bash
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
pre-commit install
|
||||
|
||||
# test
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: 'NO'
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,40 +0,0 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- c_attn
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: pythia-1.4b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 32
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.0003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: True
|
||||
tf32: True
|
||||
gradient_checkpointing:
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,41 +0,0 @@
|
||||
base_model: facebook/galactica-1.3b
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 32
|
||||
micro_batch_size: 16
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
tf32: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
tokens:
|
||||
pad_token: "[PAD]"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,39 +0,0 @@
|
||||
base_model: EleutherAI/gpt-neox-20b
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: nomic-ai/gpt4all-j-prompt-generations
|
||||
type: alpaca
|
||||
shards: 4
|
||||
shards_index: 0
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- query_key_value
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project: gpt4all-neox-20b
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./gpt4all-neox-20b
|
||||
batch_size: 48
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
lr_scheduler: one_cycle
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: True
|
||||
tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,39 +0,0 @@
|
||||
base_model: huggyllama/llama-13b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
|
||||
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
|
||||
type: sharegpt
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.002
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./llama-13b-sharegpt
|
||||
batch_size: 64
|
||||
micro_batch_size: 2
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
eval_steps:
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience: 5
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,44 +0,0 @@
|
||||
base_model: huggyllama/llama-65b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
|
||||
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-65b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
micro_batch_size: 16
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca # original alpaca dataset
|
||||
type: alpaca
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-test
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
load_4bit: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
@@ -1,41 +0,0 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
micro_batch_size: 16
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca # original alpaca dataset
|
||||
type: alpaca
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-test
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
gptq: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
@@ -1,86 +0,0 @@
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
tokenizer_type: AutoTokenizer
|
||||
# whether you are training a 4-bit quantized model
|
||||
load_4bit: true
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# this can be either a hf dataset, or relative path
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
|
||||
val_set_size: 0.04
|
||||
# if you want to use lora, leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
lora_model_dir:
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
max_packed_sequence_len: 1024
|
||||
# lora hyperparameters
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
# wandb configuration if your're using it
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
# where to save the finsihed model to
|
||||
output_dir: ./completed-model
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
group_by_length: false
|
||||
# Use CUDA bf16
|
||||
bf16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true
|
||||
# does not work with current implementation of 4-bit LoRA
|
||||
gradient_checkpointing: false
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
# specify a scheduler to use with the optimizer. only one_cycle is supported currently
|
||||
lr_scheduler:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
||||
flash_attention:
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
auto_resume_from_checkpoints: false
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
@@ -1,56 +0,0 @@
|
||||
base_model: stabilityai/stablelm-base-alpha-3b
|
||||
base_model_config: stabilityai/stablelm-base-alpha-3b
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
max_packed_sequence_len: 4096
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: stable-alpaca-3b
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./stable-alpaca-3b
|
||||
batch_size: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 100
|
||||
eval_steps: 50
|
||||
save_steps: 200
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.01
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
#tokens:
|
||||
# pad_token: "[PAD]"
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
|
||||
base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_4bit: true
|
||||
gptq_groupsize: 128
|
||||
gptq_model_v1: false
|
||||
datasets:
|
||||
# https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
|
||||
- path: data/alpaca_reflect_pruned.jsonl
|
||||
type: reflection
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-reflect
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
flash_attention: true
|
||||
@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
|
||||
## Convert the JSON data files to JSONL.
|
||||
|
||||
```shell
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
```
|
||||
---
|
||||
|
||||
|
||||
@@ -37,18 +37,18 @@
|
||||
"lr": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.999
|
||||
0.95
|
||||
],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "OneCycle",
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"cycle_min_lr": 0.00001,
|
||||
"cycle_max_lr": 0.00003,
|
||||
"cycle_first_step_size": 120
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"train_batch_size": "auto",
|
||||
20
docker-compose.yaml
Normal file
20
docker-compose.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
# version: '3.8'
|
||||
services:
|
||||
axolotl:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: ./docker/Dockerfile
|
||||
volumes:
|
||||
- .:/workspace/axolotl
|
||||
- ~/.cache/huggingface/:/root/.cache/huggingface/
|
||||
# set environment variables
|
||||
environment:
|
||||
- WANDB_API_KEY=${WANDB_API_KEY}
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
# count: 1
|
||||
capabilities: [gpu]
|
||||
command: tail -f /dev/null
|
||||
@@ -2,19 +2,29 @@ ARG BASE_TAG=main-base
|
||||
FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
RUN python3 -m pip install -U --no-cache-dir pydantic
|
||||
|
||||
RUN mkdir axolotl
|
||||
COPY . axolotl/
|
||||
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main"
|
||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN cd axolotl && \
|
||||
pip install -e .[int4]
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN cd axolotl && \
|
||||
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
|
||||
@@ -8,8 +8,8 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH="2.0.0"
|
||||
ARG CUDA="cu118"
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG CUDA="118"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
|
||||
@@ -29,17 +29,18 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
|
||||
FROM base-builder AS flash-attn-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
||||
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
||||
cd flash-attention && \
|
||||
git checkout v2.0.1 && \
|
||||
python3 setup.py bdist_wheel && \
|
||||
cd csrc/fused_dense_lib && \
|
||||
python3 setup.py bdist_wheel && \
|
||||
@@ -52,6 +53,8 @@ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
@@ -61,21 +64,24 @@ RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA_VERSION_BNB="118"
|
||||
ENV CUDA_VERSION_BNB=$CUDA_VERSION_BNB
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
CUDA_VERSION=$CUDA_VERSION_BNB make cuda11x && \
|
||||
CUDA_VERSION=$CUDA make cuda11x && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM base-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
# recompile apex
|
||||
RUN python3 -m pip uninstall -y apex
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
|
||||
RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
|
||||
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
@@ -93,10 +99,6 @@ COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/d
|
||||
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main" && \
|
||||
pip3 install awscli && \
|
||||
RUN pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic
|
||||
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM winglian/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||
|
||||
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||
|
||||
61
examples/cerebras/qlora.yml
Normal file
61
examples/cerebras/qlora.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- c_fc
|
||||
- c_attn
|
||||
- c_proj
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -23,7 +23,8 @@ lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: falcon-7b
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -61,4 +62,3 @@ special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
|
||||
93
examples/falcon/config-7b-qlora.yml
Normal file
93
examples/falcon/config-7b-qlora.yml
Normal file
@@ -0,0 +1,93 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
gptq: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: QingyiSi/Alpaca-CoT
|
||||
data_files:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
|
||||
# hyperparameters from QLoRA paper Appendix B.2
|
||||
# "We find hyperparameters to be largely robust across datasets"
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
# 0.1 for models up to 13B
|
||||
# 0.05 for 33B and 65B models
|
||||
lora_dropout: 0.05
|
||||
# add LoRA modules on all linear layers of the base model
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
# QLoRA paper Table 9
|
||||
# - 16 for 7b & 13b
|
||||
# - 32 for 33b, 64 for 64b
|
||||
# Max size tested on A6000
|
||||
# - 7b: 40
|
||||
# - 40b: 4
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
# QLoRA paper Table 9
|
||||
# - 2e-4 for 7b & 13b
|
||||
# - 1e-4 for 33b & 64b
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 5
|
||||
save_steps: 10
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.000001
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
@@ -23,7 +23,8 @@ lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: falcon-7b
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -61,4 +62,3 @@ special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
|
||||
58
examples/gptj/qlora.yml
Normal file
58
examples/gptj/qlora.yml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
base_model_config: EleutherAI/gpt-j-6b
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -3,6 +3,6 @@
|
||||
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
|
||||
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
|
||||
|
||||
```
|
||||
|
||||
@@ -22,11 +22,12 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora-int4
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-7b-lora-int4
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
@@ -7,30 +7,29 @@ datasets:
|
||||
- path: openaccess-ai-collective/jeopardy
|
||||
type: jeopardy
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
sequence_len: 512
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: jeopardy-bot-7b
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
batch_size: 4
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
@@ -48,11 +47,10 @@ eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
tokens:
|
||||
pad_token: "[PAD]"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
20
examples/llama-2/README.md
Normal file
20
examples/llama-2/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Overview
|
||||
|
||||
This is an example of a llama-2 configuration for 7b and 13b. The yaml file contains configuration for the 7b variant, but you can just aswell use the same settings for 13b.
|
||||
|
||||
The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
|
||||
|
||||
The 13b variant will fit if you change these settings to these values:
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
|
||||
|
||||
```
|
||||
or
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
|
||||
|
||||
```
|
||||
67
examples/llama-2/lora.yml
Normal file
67
examples/llama-2/lora.yml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
max_packed_sequence_len: 4096
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
pad_token: "<pad>"
|
||||
68
examples/llama-2/qlora.yml
Normal file
68
examples/llama-2/qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
max_packed_sequence_len: 4096
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
pad_token: "<pad>"
|
||||
@@ -20,11 +20,12 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
16
examples/openllama-3b/README.md
Normal file
16
examples/openllama-3b/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# openllama-3b
|
||||
|
||||
Basic full tune
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
|
||||
```
|
||||
|
||||
LoRA
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
QLoRA
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
|
||||
```
|
||||
63
examples/openllama-3b/config.yml
Normal file
63
examples/openllama-3b/config.yml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 256
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
float16: true
|
||||
bf16: false
|
||||
fp16: false
|
||||
tf32: false
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model_config: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -28,6 +28,7 @@ lora_target_modules:
|
||||
- o_proj
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
@@ -49,7 +50,7 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
62
examples/openllama-3b/qlora.yml
Normal file
62
examples/openllama-3b/qlora.yml
Normal file
@@ -0,0 +1,62 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
9
examples/pythia-12b/README.md
Normal file
9
examples/pythia-12b/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Pythia 12B
|
||||
|
||||
- Single-GPU A100 only (?)
|
||||
|
||||
```shell
|
||||
python scripts/finetune.py examples/pythia-12b/config.yml
|
||||
```
|
||||
|
||||
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️
|
||||
50
examples/pythia-12b/config.yml
Normal file
50
examples/pythia-12b/config.yml
Normal file
@@ -0,0 +1,50 @@
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_config: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
device_map: auto
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 64
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
fp16: false
|
||||
float16: true
|
||||
tf32: true
|
||||
flash_optimum: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
gradient_checkpointing: true
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
collator_pad_to_longest: true
|
||||
@@ -1,36 +1,30 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
sequence_len: 512
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- query_key_value
|
||||
# - xxx
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project: pythia-1.4b-lora
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 48
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
@@ -39,3 +33,6 @@ tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 20
|
||||
logging_steps: 1
|
||||
@@ -1,7 +1,7 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: GPTNeoXTokenizer
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code:
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
@@ -21,9 +21,10 @@ lora_target_modules:
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model: checkpoint
|
||||
wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
|
||||
@@ -20,6 +20,7 @@ lora_target_modules:
|
||||
- mlp_down
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
91
examples/xgen-7b/xgen-7b-8k-qlora.yml
Normal file
91
examples/xgen-7b/xgen-7b-8k-qlora.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
||||
# on Tim Dettmer's Guanaco dataset.
|
||||
base_model: Salesforce/xgen-7b-8k-base
|
||||
base_model_config: Salesforce/xgen-7b-8k-base
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
gptq: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: timdettmers/openassistant-guanaco
|
||||
data_files:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 8192
|
||||
max_packed_sequence_len:
|
||||
|
||||
# hyperparameters from QLoRA paper Appendix B.2
|
||||
# "We find hyperparameters to be largely robust across datasets"
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
# 0.1 for models up to 13B
|
||||
# 0.05 for 33B and 65B models
|
||||
lora_dropout: 0.05
|
||||
# add LoRA modules on all linear layers of the base model
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
# QLoRA paper Table 9
|
||||
# - 16 for 7b & 13b
|
||||
# - 32 for 33b, 64 for 64b
|
||||
# Max size tested on A6000
|
||||
# - 7b: 40
|
||||
# - 40b: 4
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
# QLoRA paper Table 9
|
||||
# - 2e-4 for 7b & 13b
|
||||
# - 1e-4 for 33b & 64b
|
||||
learning_rate: 0.00002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
gradient_checkpointing: true
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps: 50
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
eos_token: "<|endoftext|>"
|
||||
bos_token: "<|endoftext|>"
|
||||
unk_token: "<|endoftext|>"
|
||||
pad_token: "<|endoftext|>"
|
||||
BIN
image/axolotl-badge-web.png
Normal file
BIN
image/axolotl-badge-web.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
3
requirements-dev.txt
Normal file
3
requirements-dev.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pre-commit
|
||||
black
|
||||
mypy
|
||||
@@ -1,19 +1,24 @@
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.39.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
||||
addict
|
||||
fire
|
||||
PyYAML==6.0
|
||||
black
|
||||
datasets
|
||||
accelerate>=0.19.0
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers
|
||||
optimum
|
||||
hf_transfer
|
||||
numba
|
||||
numpy==1.24.4
|
||||
# qlora things
|
||||
bert-score==0.3.13
|
||||
evaluate==0.4.0
|
||||
rouge-score==0.1.2
|
||||
scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
|
||||
@@ -1,24 +1,41 @@
|
||||
"""Module to convert json file to jsonl"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import fire
|
||||
from typing import Optional
|
||||
|
||||
from axolotl.convert import (
|
||||
FileReader,
|
||||
FileWriter,
|
||||
JsonlSerializer,
|
||||
JsonParser,
|
||||
JsonToJsonlConverter,
|
||||
StdoutWriter,
|
||||
)
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
configure_logging()
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from axolotl.convert import *
|
||||
|
||||
|
||||
def main(
|
||||
input: Path,
|
||||
file: Path,
|
||||
output: Optional[Path] = None,
|
||||
to_stdout: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Convert a json file to jsonl
|
||||
"""
|
||||
|
||||
file_reader = FileReader()
|
||||
writer: Union[StdoutWriter, FileWriter]
|
||||
if to_stdout or output is None:
|
||||
writer = StdoutWriter()
|
||||
else:
|
||||
@@ -28,7 +45,7 @@ def main(
|
||||
|
||||
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
||||
|
||||
converter.convert(input, output)
|
||||
converter.convert(file, output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
@@ -5,94 +7,140 @@ import random
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.validation import validate_config
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import (
|
||||
calculate_total_num_steps,
|
||||
process_datasets_for_packing,
|
||||
setup_trainer,
|
||||
)
|
||||
from axolotl.utils.validation import validate_config
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from axolotl.utils.data import load_prepare_datasets
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
|
||||
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
|
||||
def choose_device(cfg):
|
||||
def get_device():
|
||||
if torch.cuda.is_available():
|
||||
return f"cuda:{cfg.local_rank}"
|
||||
else:
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
except:
|
||||
return "cpu"
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
return f"cuda:{cfg.local_rank}"
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
|
||||
raise SystemError("No CUDA/mps device found")
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.device == "cuda":
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
if cfg.device_map != "auto":
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to finish): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||
return instruction
|
||||
|
||||
|
||||
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
||||
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
||||
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
||||
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# gc = GenerationConfig() # TODO swap out and use this
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=100,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextStreamer(tokenizer)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = [file for file in path.glob("*.yml")]
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
if not yaml_files:
|
||||
raise ValueError(
|
||||
@@ -130,31 +178,37 @@ def train(
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, "r") as f:
|
||||
cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k in kwargs:
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or cfg.strict is False:
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.gradient_accumulation_steps = (
|
||||
cfg.gradient_accumulation_steps // cfg.world_size
|
||||
)
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
@@ -163,88 +217,126 @@ def train(
|
||||
cfg.fp16 = True
|
||||
cfg.bf16 = False
|
||||
|
||||
validate_config(cfg)
|
||||
if cfg.tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# load the tokenizer first
|
||||
logging.info("loading tokenizer...")
|
||||
tokenizer = load_tokenizer(
|
||||
cfg.base_model_config,
|
||||
cfg.tokenizer_type,
|
||||
cfg
|
||||
)
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||
|
||||
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
if (
|
||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||
): # don't need to load dataset for these
|
||||
if not cfg.pretraining_dataset:
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
train_dataset = load_pretraining_dataset(
|
||||
cfg.pretraining_dataset,
|
||||
tokenizer,
|
||||
max_tokens=cfg.sequence_len,
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
|
||||
if is_main_process():
|
||||
# process on rank 0 first so it gets cached so other ranks load from cache
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
barrier()
|
||||
if not is_main_process():
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
barrier()
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
logging.info("check_dataset_labels...")
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
|
||||
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
||||
),
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
if prepare_ds_only:
|
||||
logging.info("Finished preparing dataset. Exiting...")
|
||||
LOG.info("Finished preparing dataset. Exiting...")
|
||||
return
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
# Load the model and tokenizer
|
||||
logging.info("loading model and peft_config...")
|
||||
model, peft_config = load_model(
|
||||
cfg.base_model,
|
||||
cfg.base_model_config,
|
||||
cfg.model_type,
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter=cfg.adapter,
|
||||
inference=("inference" in kwargs),
|
||||
)
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, peft_config = load_model(cfg, tokenizer)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||
logging.info("running merge of LoRA with base model")
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=torch.float16)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
logging.info("saving merged model")
|
||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
LOG.info("saving merged model")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
return
|
||||
|
||||
if "inference" in kwargs:
|
||||
logging.info("calling do_inference function")
|
||||
do_inference(cfg, model, tokenizer)
|
||||
if cfg.inference:
|
||||
LOG.info("calling do_inference function")
|
||||
prompter: Optional[str] = "AlpacaPrompter"
|
||||
if "prompter" in kwargs:
|
||||
if kwargs["prompter"] == "None":
|
||||
prompter = None
|
||||
else:
|
||||
prompter = kwargs["prompter"]
|
||||
do_inference(cfg, model, tokenizer, prompter=prompter)
|
||||
return
|
||||
|
||||
if "shard" in kwargs:
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
return
|
||||
|
||||
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
logging.info("Compiling torch model")
|
||||
LOG.info("Compiling torch model")
|
||||
model = torch.compile(model)
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
peft_config.save_pretrained(cfg.output_dir)
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(
|
||||
signal.SIGINT,
|
||||
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
logging.info("Starting trainer...")
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
logging.info("hang tight... sorting dataset for group_by_length")
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
@@ -252,20 +344,39 @@ def train(
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints, key=lambda path: int(path.split("-")[-1])
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
resume_from_checkpoint = sorted_paths[-1]
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
||||
)
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
if cfg.fsdp:
|
||||
trainer.save_model(cfg.output_dir)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
if cfg.adapter == "lora" and cfg.relora_steps:
|
||||
model = model.merge_and_unload()
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
|
||||
|
||||
19
scripts/runpod-entrypoint.sh
Normal file → Executable file
19
scripts/runpod-entrypoint.sh
Normal file → Executable file
@@ -1,10 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
|
||||
chmod 700 -R ~/.ssh
|
||||
# Export specific ENV variables to /etc/rp_environment
|
||||
echo "Exporting environment variables..."
|
||||
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||
|
||||
# Start the SSH service in the background
|
||||
service ssh start
|
||||
if [[ $PUBLIC_KEY ]]
|
||||
then
|
||||
mkdir -p ~/.ssh
|
||||
chmod 700 ~/.ssh
|
||||
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
|
||||
chmod 700 -R ~/.ssh
|
||||
# Start the SSH service in the background
|
||||
service ssh start
|
||||
else
|
||||
echo "No PUBLIC_KEY ENV variable provided, not starting openSSH daemon"
|
||||
fi
|
||||
|
||||
# Execute the passed arguments (CMD)
|
||||
exec "$@"
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
export WANDB_MODE=offline
|
||||
export WANDB_CACHE_DIR=/workspace/data/wandb-cache
|
||||
mkdir -p $WANDB_CACHE_DIR
|
||||
|
||||
mkdir -p /workspace/data/huggingface-cache/{hub,datasets}
|
||||
export HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
export HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
export TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
export NCCL_P2P_DISABLE=1
|
||||
|
||||
nvidia-smi
|
||||
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
gpu_indices=$(seq 0 $((num_gpus - 1)) | paste -sd "," -)
|
||||
export CUDA_VISIBLE_DEVICES=$gpu_indices
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
|
||||
apt-get update
|
||||
apt-get install -y build-essential ninja-build vim git-lfs
|
||||
git lfs install
|
||||
pip3 install --force-reinstall https://download.pytorch.org/whl/nightly/cu117/torch-2.0.0.dev20230301%2Bcu117-cp38-cp38-linux_x86_64.whl --index-url https://download.pytorch.org/whl/nightly/cu117
|
||||
if [ -z "${TORCH_CUDA_ARCH_LIST}" ]; then # only set this if not set yet
|
||||
# this covers most common GPUs that the installed version of pytorch supports
|
||||
# python -c "import torch; print(torch.cuda.get_arch_list())"
|
||||
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
fi
|
||||
|
||||
# install flash-attn and deepspeed from pre-built wheels for this specific container b/c these take forever to install
|
||||
mkdir -p /workspace/wheels
|
||||
cd /workspace/wheels
|
||||
curl -L -O https://github.com/OpenAccess-AI-Collective/axolotl/raw/wheels/wheels/deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
|
||||
curl -L -O https://github.com/OpenAccess-AI-Collective/axolotl/raw/wheels/wheels/flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
|
||||
pip install deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
|
||||
pip install flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
|
||||
pip install "peft @ git+https://github.com/huggingface/peft.git@main" --force-reinstall --no-dependencies
|
||||
|
||||
cd /workspace/
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
cd axolotl
|
||||
pip install -e .[int4]
|
||||
mkdir -p ~/.cache/huggingface/accelerate/
|
||||
cp configs/accelerate/default_config.yaml ~/.cache/huggingface/accelerate/default_config.yaml
|
||||
10
setup.py
10
setup.py
@@ -1,7 +1,9 @@
|
||||
from setuptools import setup, find_packages
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
install_requires = []
|
||||
with open("./requirements.txt", "r") as requirements_file:
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
# don't include peft yet until we check the int4
|
||||
# need to manually install peft for now...
|
||||
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
|
||||
@@ -17,10 +19,10 @@ setup(
|
||||
packages=find_packages(),
|
||||
install_requires=install_requires,
|
||||
extras_require={
|
||||
"int4": [
|
||||
"gptq": [
|
||||
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"int4_triton": [
|
||||
"gptq_triton": [
|
||||
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"extras": [
|
||||
|
||||
@@ -1,47 +1,76 @@
|
||||
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
|
||||
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
class FileReader:
|
||||
"""
|
||||
Reads a file and returns its contents as a string
|
||||
"""
|
||||
|
||||
def read(self, file_path):
|
||||
with open(file_path, "r") as file:
|
||||
with open(file_path, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
class FileWriter:
|
||||
"""
|
||||
Writes a string to a file
|
||||
"""
|
||||
|
||||
def __init__(self, file_path):
|
||||
self.file_path = file_path
|
||||
|
||||
def write(self, content):
|
||||
with open(self.file_path, "w") as file:
|
||||
with open(self.file_path, "w", encoding="utf-8") as file:
|
||||
file.write(content)
|
||||
|
||||
|
||||
class StdoutWriter:
|
||||
"""
|
||||
Writes a string to stdout
|
||||
"""
|
||||
|
||||
def write(self, content):
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
|
||||
class JsonParser:
|
||||
"""
|
||||
Parses a string as JSON and returns the result
|
||||
"""
|
||||
|
||||
def parse(self, content):
|
||||
return json.loads(content)
|
||||
|
||||
|
||||
class JsonlSerializer:
|
||||
"""
|
||||
Serializes a list of JSON objects into a JSONL string
|
||||
"""
|
||||
|
||||
def serialize(self, data):
|
||||
lines = [json.dumps(item) for item in data]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class JsonToJsonlConverter:
|
||||
"""
|
||||
Converts a JSON file to JSONL
|
||||
"""
|
||||
|
||||
def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
|
||||
self.file_reader = file_reader
|
||||
self.file_writer = file_writer
|
||||
self.json_parser = json_parser
|
||||
self.jsonl_serializer = jsonl_serializer
|
||||
|
||||
def convert(self, input_file_path, output_file_path):
|
||||
def convert(
|
||||
self, input_file_path, output_file_path
|
||||
): # pylint: disable=unused-argument
|
||||
content = self.file_reader.read(input_file_path)
|
||||
data = self.json_parser.parse(content)
|
||||
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import IterableDataset
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||
# lets use the concept of middlewares to wrap each dataset, for example
|
||||
@@ -12,24 +15,34 @@ from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
|
||||
# let's check to ensure we don't truncate an item in the middle, we'll use
|
||||
# the collators later on to pad the datasets
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
class TokenizedPromptDataset(IterableDataset):
|
||||
def __init__(
|
||||
|
||||
class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.dataset = dataset
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
# Loop through the entire dataset
|
||||
for example in iterator:
|
||||
try:
|
||||
yield self.prompt_tokenizer.tokenize_prompt(example)
|
||||
except InvalidDataException:
|
||||
pass
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
)
|
||||
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
@@ -42,7 +55,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
tokenizer,
|
||||
datasets,
|
||||
@@ -63,14 +76,21 @@ class ConstantLengthDataset(IterableDataset):
|
||||
self.tokens_dtype = torch.int64
|
||||
|
||||
def __iter__(self):
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
for dataset in self.datasets:
|
||||
idx = 0
|
||||
iterator = iter(dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
try:
|
||||
example = next(iterator)
|
||||
idx += 1
|
||||
except StopIteration:
|
||||
more_examples = False
|
||||
example = None
|
||||
@@ -82,10 +102,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
else:
|
||||
example_len = 0
|
||||
|
||||
if (
|
||||
not example_len
|
||||
or buffer_len + int(add_concat_token) + example_len
|
||||
> self.seq_length
|
||||
if not example_len or (
|
||||
buffer_len + int(add_concat_token) + example_len > self.seq_length
|
||||
):
|
||||
if buffer["input_ids"]:
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||
@@ -94,24 +112,34 @@ class ConstantLengthDataset(IterableDataset):
|
||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
if (
|
||||
labels.size() == input_ids.size()
|
||||
and attention_mask.size() == input_ids.size()
|
||||
if labels.size() == input_ids.size() and (
|
||||
attention_mask.size() == input_ids.size()
|
||||
):
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
else:
|
||||
logging.warning(
|
||||
LOG.warning(
|
||||
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
||||
)
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"position_ids": [],
|
||||
}
|
||||
buffer_len = 0
|
||||
idx = 1
|
||||
|
||||
if example:
|
||||
# FIXME
|
||||
# just going to drop data points that are too long
|
||||
if len(example["input_ids"]) <= self.seq_length:
|
||||
input_ids = example["input_ids"]
|
||||
@@ -127,13 +155,17 @@ class ConstantLengthDataset(IterableDataset):
|
||||
input_ids, dtype=self.tokens_dtype
|
||||
)
|
||||
attention_mask_with_concat = torch.tensor(
|
||||
attention_mask, dtype=self.tokens_dtype
|
||||
[idx * m for m in attention_mask], dtype=torch.int16
|
||||
)
|
||||
labels_with_concat = torch.tensor(
|
||||
labels, dtype=self.tokens_dtype
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
len(input_ids), dtype=self.tokens_dtype
|
||||
)
|
||||
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||
buffer["labels"].append(labels_with_concat)
|
||||
buffer["position_ids"].append(position_ids)
|
||||
buffer_len += len(input_ids)
|
||||
|
||||
33
src/axolotl/logging_config.py
Normal file
33
src/axolotl/logging_config.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Logging configuration settings"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from logging.config import dictConfig
|
||||
from typing import Any, Dict
|
||||
|
||||
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
"version": 1,
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
||||
},
|
||||
},
|
||||
"filters": {},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "simple",
|
||||
"filters": [],
|
||||
"stream": sys.stdout,
|
||||
},
|
||||
},
|
||||
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
|
||||
"loggers": {
|
||||
"axolotl": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging():
|
||||
"""Configure with default logging"""
|
||||
dictConfig(DEFAULT_LOGGING_CONFIG)
|
||||
@@ -1,17 +1,24 @@
|
||||
"""Flash attention monkey patch for llama model"""
|
||||
|
||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
except ImportError:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
|
||||
def forward(
|
||||
@@ -27,6 +34,7 @@ def forward(
|
||||
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
@@ -74,43 +82,79 @@ def forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
max_s = q_len
|
||||
cu_q_lens = torch.arange(
|
||||
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
||||
0,
|
||||
(bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=qkv.device,
|
||||
)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif position_ids.shape[0] == 1:
|
||||
# special handling using sample packing
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_q_lens = cu_q_lens.squeeze()
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(
|
||||
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
||||
x_unpad,
|
||||
"nnz (three h d) -> nnz three h d",
|
||||
three=3,
|
||||
h=nheads,
|
||||
)
|
||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
x_unpad,
|
||||
cu_q_lens,
|
||||
max_s,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(
|
||||
pad_input(
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
||||
indices,
|
||||
bsz,
|
||||
q_len,
|
||||
),
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
||||
|
||||
return (
|
||||
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
return attention_mask
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn():
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
||||
284
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
284
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from torch import nn
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
logging.error("xformers not found! Please install it before trying to use it.")
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = transformers.models.llama.modeling_llama.repeat_kv(
|
||||
key_states, self.num_key_value_groups
|
||||
)
|
||||
value_states = transformers.models.llama.modeling_llama.repeat_kv(
|
||||
value_states, self.num_key_value_groups
|
||||
)
|
||||
|
||||
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states, key_states, value_states, attn_bias=None
|
||||
)
|
||||
else:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
# attn_bias=attention_mask,
|
||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# end x-formers vs. not x-formers if-else block
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
with torch.backends.cuda.sdp_kernel():
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
52
src/axolotl/monkeypatch/llama_expand_mask.py
Normal file
52
src/axolotl/monkeypatch/llama_expand_mask.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def hijack_expand_mask():
|
||||
import transformers
|
||||
|
||||
transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
|
||||
_expand_mask
|
||||
)
|
||||
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
302
src/axolotl/monkeypatch/relora.py
Normal file
302
src/axolotl/monkeypatch/relora.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# pylint: skip-file
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.relora")
|
||||
|
||||
|
||||
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
param_state = optimizer.state[param]
|
||||
for key in param_state:
|
||||
if "qmap" in key:
|
||||
continue
|
||||
elif key == "step" and isinstance(param_state[key], int):
|
||||
param_state[key] = 0
|
||||
else:
|
||||
param_state[key] = torch.zeros_like(param_state[key])
|
||||
|
||||
|
||||
class ReLoRACallback(TrainerCallback):
|
||||
def __init__(self, cfg: DictDefault):
|
||||
self.relora_steps = cfg.relora_steps
|
||||
self.cpu_offload = cfg.relora_cpu_offload
|
||||
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
|
||||
self.last_full_model = cfg.base_model
|
||||
|
||||
assert os.path.exists(
|
||||
self.last_full_model
|
||||
), "for ReLORA base_model must be a local path"
|
||||
|
||||
self.num_lora_restarts = 0
|
||||
self.need_full_save = False
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
**_kwargs,
|
||||
):
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
merge_and_save(
|
||||
model,
|
||||
self.last_full_model,
|
||||
checkpoint_folder,
|
||||
reinit=True,
|
||||
quantized=self.quantised,
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
if self.quantised:
|
||||
self.last_full_model = checkpoint_folder
|
||||
self.num_lora_restarts += 1
|
||||
|
||||
return control
|
||||
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
**kwargs,
|
||||
):
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
if (
|
||||
state.global_step >= self.relora_steps
|
||||
and state.global_step % self.relora_steps != 0
|
||||
):
|
||||
if self.quantised and self.last_full_model != checkpoint_folder:
|
||||
# ensure the latest full parameter save is in the latest checkpoint
|
||||
# folder, so that automatic pruning of checkpoints does not remove it
|
||||
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
||||
chunks = glob.glob(
|
||||
f"{self.last_full_model}/model*.safetensors"
|
||||
) + glob.glob(f"{self.last_full_model}/model*.index.json")
|
||||
for path in chunks:
|
||||
shutil.move(path, checkpoint_folder)
|
||||
self.last_full_model = checkpoint_folder
|
||||
else:
|
||||
model.model.save_pretrained(checkpoint_folder, save_safetensors=True)
|
||||
|
||||
return control
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
_state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs: Dict[str, float],
|
||||
**_kwargs,
|
||||
):
|
||||
logs["num_lora_restarts"] = self.num_lora_restarts
|
||||
return control
|
||||
|
||||
|
||||
class ReLoRAScheduler(LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
inner_schedule: LRScheduler,
|
||||
relora_steps: int,
|
||||
warmup_steps: int,
|
||||
min_lr_scale: float = 0.001,
|
||||
) -> None:
|
||||
self.inner_schedule = inner_schedule
|
||||
self.relora_steps = relora_steps
|
||||
self.warmup_steps = warmup_steps
|
||||
self.min_lr_scale = min_lr_scale
|
||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||
|
||||
def get_lr(self) -> float:
|
||||
self.inner_schedule.last_epoch = self.last_epoch
|
||||
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
if step < self.relora_steps:
|
||||
scale = 1
|
||||
else:
|
||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||
if isinstance(original, Sequence):
|
||||
return [lr * scale for lr in original]
|
||||
else:
|
||||
return original * scale
|
||||
|
||||
|
||||
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
||||
model_name = "model.safetensors"
|
||||
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||
str(Path(path) / f"{model_name}.index.json")
|
||||
):
|
||||
model_name = "pytorch_model.bin"
|
||||
|
||||
index_path = str(Path(path) / f"{model_name}.index.json")
|
||||
if os.path.exists(index_path):
|
||||
data = json.load(open(index_path, "r"))
|
||||
return data["weight_map"]
|
||||
return {key + ".weight": model_name for key in keys}
|
||||
|
||||
|
||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer) -> torch.Tensor:
|
||||
if isinstance(layer, peft.tuners.lora.Linear8bitLt) or isinstance(
|
||||
layer, peft.tuners.lora.Linear4bit
|
||||
):
|
||||
adapter = layer.active_adapter
|
||||
return (
|
||||
peft.utils.transpose(
|
||||
layer.lora_B[adapter].weight @ layer.lora_A[adapter].weight,
|
||||
getattr(layer, "fan_in_fan_out", False),
|
||||
)
|
||||
* layer.scaling[adapter]
|
||||
)
|
||||
else:
|
||||
return layer.get_delta_weight()
|
||||
|
||||
|
||||
def merge_and_save(
|
||||
model: peft.LoraModel,
|
||||
model_src: str,
|
||||
model_dst: str,
|
||||
reinit: bool = False,
|
||||
quantized: bool = False,
|
||||
cpu_offload: bool = False,
|
||||
):
|
||||
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
||||
|
||||
if not quantized:
|
||||
for key in key_list:
|
||||
try:
|
||||
_parent, target, _target_name = peft.utils._get_submodules(
|
||||
model.model, key
|
||||
)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||
update = target.get_delta_weight(target.active_adapter).detach()
|
||||
target.weight.data += update
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
return
|
||||
|
||||
os.makedirs(model_dst, exist_ok=True)
|
||||
shard_paths = sharded_paths(model_src, key_list)
|
||||
|
||||
unique_shards = list(set(shard_paths.values()))
|
||||
for shard_path in unique_shards:
|
||||
out_tensors = {}
|
||||
if shard_path.endswith(".safetensors"):
|
||||
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
||||
else:
|
||||
in_tensors = torch.load(Path(model_src) / shard_path)
|
||||
if "state_dict" in in_tensors:
|
||||
in_tensors = in_tensors["state_dict"]
|
||||
|
||||
for key in key_list:
|
||||
if (key + ".weight") not in shard_paths or shard_paths[
|
||||
key + ".weight"
|
||||
] != shard_path:
|
||||
continue
|
||||
|
||||
try:
|
||||
_parent, target, _target_name = peft.utils._get_submodules(
|
||||
model.model, key
|
||||
)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||
orig_weight = in_tensors[key + ".weight"]
|
||||
old_dev = target.weight.device
|
||||
math_dev = "cpu" if cpu_offload else old_dev
|
||||
|
||||
update = lora_delta_weight(target).detach().to(math_dev)
|
||||
new_weight = orig_weight.to(math_dev) + update
|
||||
out_tensors[key + ".weight"] = new_weight
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
|
||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||
target.weight = (
|
||||
bnb.nn.Params4bit(
|
||||
new_weight,
|
||||
requires_grad=False,
|
||||
compress_statistics=target.weight.compress_statistics,
|
||||
quant_type=target.weight.quant_type,
|
||||
)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
target.weight = (
|
||||
bnb.nn.Int8Params(new_weight, requires_grad=False)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
else:
|
||||
target.weight.data = new_weight.to(old_dev)
|
||||
|
||||
for key in in_tensors:
|
||||
if key not in out_tensors:
|
||||
out_tensors[key] = in_tensors[key]
|
||||
del in_tensors
|
||||
|
||||
out_shard_name = shard_path
|
||||
if out_shard_name.startswith("pytorch_model"):
|
||||
out_shard_name = (
|
||||
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
||||
+ ".safetensors"
|
||||
)
|
||||
|
||||
shard_fn = str(Path(model_dst) / out_shard_name)
|
||||
LOG.info(f"saving tensors to {shard_fn}")
|
||||
st.save_file(out_tensors, shard_fn)
|
||||
del out_tensors
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if len(unique_shards) > 1:
|
||||
with open(str(Path(model_dst, "model.safetensors.index.json")), "w") as fd:
|
||||
json.dump({"metadata": {}, "weight_map": shard_paths}, fd)
|
||||
103
src/axolotl/monkeypatch/utils.py
Normal file
103
src/axolotl/monkeypatch/utils.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def get_cu_seqlens(attn_mask):
|
||||
"""generate a cumulative sequence length mask for flash attention using attn mask"""
|
||||
if len(attn_mask.shape) == 1:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
|
||||
device = attn_mask.device
|
||||
results = []
|
||||
max_seq_lens = []
|
||||
|
||||
for row in attn_mask:
|
||||
# Exclude zeros to avoid adding their positions to the mask
|
||||
t_non_zeros = row[row != 0]
|
||||
# Find where the sequence number changes (including the first position)
|
||||
seq_change = torch.cat(
|
||||
[
|
||||
torch.tensor([1], dtype=torch.int32, device=device),
|
||||
t_non_zeros[1:] != t_non_zeros[:-1],
|
||||
]
|
||||
)
|
||||
# Get the indices where the sequence changes
|
||||
change_indices = torch.cat(
|
||||
[
|
||||
(seq_change == 1).nonzero(as_tuple=True)[0],
|
||||
torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
|
||||
]
|
||||
)
|
||||
# Calculate the sequence lengths
|
||||
seq_lengths = change_indices[1:] - change_indices[:-1]
|
||||
# Calculate the length of the final sequence or padding
|
||||
final_seq_length = len(row) - change_indices[-1]
|
||||
# Append the length of the final sequence or padding to seq_lengths
|
||||
if final_seq_length.item():
|
||||
seq_lengths = torch.cat(
|
||||
[
|
||||
seq_lengths,
|
||||
torch.tensor(
|
||||
[final_seq_length.item()], dtype=torch.int32, device=device
|
||||
),
|
||||
]
|
||||
)
|
||||
# Calculate the cumulative sequence lengths
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
||||
)
|
||||
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||
if len(position_ids.shape) == 1:
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
device = position_ids.device
|
||||
results = []
|
||||
max_seq_lens = []
|
||||
|
||||
for row in position_ids:
|
||||
# Count the number of consecutive zeros from the right side
|
||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
||||
|
||||
# Adjust the row to exclude padding
|
||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
||||
|
||||
# Find where the position resets to 0 (indicating a new sequence)
|
||||
seq_starts = torch.cat(
|
||||
[
|
||||
torch.tensor([True], dtype=torch.bool, device=device),
|
||||
adjusted_row[1:] == 0,
|
||||
]
|
||||
)
|
||||
# Get the indices where the sequence starts
|
||||
start_indices = torch.cat(
|
||||
[
|
||||
(seq_starts).nonzero(as_tuple=True)[0],
|
||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
||||
]
|
||||
)
|
||||
# Calculate the sequence lengths
|
||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||
# Calculate the cumulative sequence lengths
|
||||
cu_seqlens = torch.cat(
|
||||
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
||||
)
|
||||
# Append the padding length to the cumulative sequence lengths
|
||||
if padding_length:
|
||||
cu_seqlens = torch.cat(
|
||||
[cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
|
||||
)
|
||||
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
"""
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.models.llama.modeling_llama
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class XposRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scale_base=2048,
|
||||
use_xpos=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.scale_base = scale_base
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
||||
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.register_buffer("freqs_cached", freqs, persistent=False)
|
||||
|
||||
if not use_xpos:
|
||||
self.register_buffer("scale", None)
|
||||
self.register_buffer("scale_cached", torch.ones(1))
|
||||
return
|
||||
|
||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
||||
scale_cached = scale ** rearrange(power, "n -> n 1")
|
||||
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
||||
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
seq_len,
|
||||
):
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
||||
self.inv_freq
|
||||
)
|
||||
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
||||
|
||||
self.register_buffer("freqs_cached", freqs)
|
||||
|
||||
if self.scale is None:
|
||||
self.register_buffer(
|
||||
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
||||
)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
||||
|
||||
power = (t - (seq_len // 2)) / self.scale_base
|
||||
scale = self.scale ** rearrange(power, "n -> n 1")
|
||||
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
||||
self.register_buffer("scale_cached", scale)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
||||
freqs = freqs[position_ids, :]
|
||||
if scale.shape[-1] != 1:
|
||||
scale = scale[position_ids, :]
|
||||
|
||||
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
||||
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def replace_llama_rope_with_xpos_rope():
|
||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Module to load prompt strategies."""
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
@@ -7,8 +9,8 @@ def load(strategy, tokenizer, cfg):
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
fn = getattr(m, load_fn)
|
||||
return fn(tokenizer, cfg)
|
||||
except:
|
||||
pass
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
func = getattr(mod, load_fn)
|
||||
return func(tokenizer, cfg)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return None
|
||||
|
||||
@@ -1,21 +1,65 @@
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
InstructionPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.chat.value),
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaConcisePrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
|
||||
class AlpacaChatPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
self.prompt_style = PromptStyle.CHAT.value
|
||||
self.match_prompt_style()
|
||||
|
||||
|
||||
class NoSystemPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Null Prompter with no system prompts
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
turn_format = "{instruction} {input} "
|
||||
turn_no_input_format = "{instruction} "
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
pass
|
||||
|
||||
|
||||
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for AlpacaQA
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
"",
|
||||
@@ -23,9 +67,49 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.chat.value),
|
||||
class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for CamelAI datasets
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["message_1"],
|
||||
"",
|
||||
prompt["message_2"],
|
||||
)
|
||||
|
||||
|
||||
def load_concise(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaConcisePrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_camel_ai(tokenizer, cfg):
|
||||
return CamelAIPromptTokenizingStrategy(
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_no_prompt(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
UnpromptedPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
|
||||
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.instruct),
|
||||
AlpacaPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_no_prompt(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
UnpromptedPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
160
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
160
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Prompt strategies loader for alpaca instruction datasets with system prompts
|
||||
"""
|
||||
from typing import Generator, Tuple, Union
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
prompt["output"],
|
||||
prompt["system"],
|
||||
)
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
system,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt_w_system(
|
||||
system,
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class SystemDataPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Style Prompter that uses system prompts from the dataset
|
||||
"""
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
formatted_sys_prompt = (
|
||||
self.system_format.format(system=system)
|
||||
if system and self.system_format
|
||||
else ""
|
||||
)
|
||||
if input:
|
||||
res = formatted_sys_prompt + self.turn_format.format(
|
||||
instruction=instruction, input=input
|
||||
)
|
||||
else:
|
||||
res = formatted_sys_prompt + self.turn_no_input_format.format(
|
||||
instruction=instruction
|
||||
)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
class OpenOrcaSystemDataPrompter(SystemDataPrompter):
|
||||
"""
|
||||
Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts
|
||||
"""
|
||||
|
||||
def match_prompt_style(self):
|
||||
# pylint: disable=duplicate-code
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
||||
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
self.system_format = "SYSTEM: {system}\n"
|
||||
if self.prompt_style == PromptStyle.CHATML.value:
|
||||
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.turn_no_input_format = (
|
||||
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
|
||||
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for OpenOrca datasets
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
"",
|
||||
prompt["response"],
|
||||
prompt["system_prompt"],
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return load_chat(tokenizer, cfg)
|
||||
|
||||
|
||||
def load_instruct(tokenizer, cfg):
|
||||
return InstructionWSystemPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_chat(tokenizer, cfg):
|
||||
return InstructionWSystemPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_open_orca(tokenizer, cfg):
|
||||
return OpenOrcaPromptTokenizingStrategy(
|
||||
OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_open_orca_chatml(tokenizer, cfg):
|
||||
return OpenOrcaPromptTokenizingStrategy(
|
||||
OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
67
src/axolotl/prompt_strategies/context_qa.py
Normal file
67
src/axolotl/prompt_strategies/context_qa.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
# article, unanswerable_question, question, answer
|
||||
def load_404(tokenizer, cfg):
|
||||
return AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
AlpacaContextPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaContextPromptTokenizingStrategy(
|
||||
AlpacaContextPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Customized system prompted for concise QA
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"Use the following contextual information to concisely answer the question.\n"
|
||||
)
|
||||
system_no_input_prompt = (
|
||||
"Use the following contextual information to concisely answer the question.\n"
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question and answer
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"] + "\n===\n" + prompt["question"],
|
||||
"",
|
||||
prompt["answer"],
|
||||
)
|
||||
|
||||
|
||||
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
InstructionPromptTokenizingStrategy
|
||||
):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question that can't be answered
|
||||
from the context and a default response to that effect
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
|
||||
"",
|
||||
"The context provided does not contain any information about your inquiry. "
|
||||
"Therefore, I'm unable to answer your question based on the given context.",
|
||||
)
|
||||
@@ -1,11 +1,18 @@
|
||||
from typing import Union, Generator
|
||||
"""Module loading the CreativePromptTokenizingStrategy and similar classes"""
|
||||
|
||||
from typing import Generator, Tuple, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
|
||||
|
||||
class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Creative Answering
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
question = prompt["instruction"]
|
||||
answer = prompt[
|
||||
"revision"
|
||||
@@ -18,6 +25,10 @@ class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrat
|
||||
|
||||
|
||||
class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Creative Critique
|
||||
"""
|
||||
|
||||
user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
|
||||
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
|
||||
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
|
||||
@@ -49,12 +60,16 @@ Question: {question}
|
||||
Answer: {answer}
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
scores = yaml.dump(
|
||||
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["scores"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
critiques = yaml.dump(
|
||||
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["critiques"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
evaluation = scores + critiques
|
||||
question = prompt["instruction"]
|
||||
@@ -67,6 +82,10 @@ Answer: {answer}
|
||||
|
||||
|
||||
class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Creative Revise
|
||||
"""
|
||||
|
||||
user_prompt = """Definitions:
|
||||
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
|
||||
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
|
||||
@@ -81,12 +100,16 @@ Evaluation:
|
||||
{evaluation}
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
scores = yaml.dump(
|
||||
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["scores"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
critiques = yaml.dump(
|
||||
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
||||
prompt["critiques"],
|
||||
default_flow_style=False,
|
||||
Dumper=yaml.Dumper,
|
||||
)
|
||||
evaluation = scores + critiques
|
||||
question = prompt["instruction"]
|
||||
@@ -101,13 +124,19 @@ Evaluation:
|
||||
|
||||
|
||||
class CreativePrompterBase:
|
||||
"""
|
||||
Base class for Creative Prompters
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None,
|
||||
input: Union[ # pylint: disable=redefined-builtin, unused-argument
|
||||
None, str
|
||||
] = None,
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
if self.system_prompt:
|
||||
@@ -120,30 +149,51 @@ class CreativePrompterBase:
|
||||
|
||||
|
||||
class CreativeAnswerPrompter(CreativePrompterBase):
|
||||
"""
|
||||
Prompter for Creative Answering
|
||||
"""
|
||||
|
||||
system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
|
||||
|
||||
|
||||
class CreativeCritiquePrompter(CreativePrompterBase):
|
||||
"""
|
||||
Prompter for Creative Critique
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
|
||||
|
||||
class CreativeRevisePrompter(CreativePrompterBase):
|
||||
"""
|
||||
Prompter for Creative Revise
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
|
||||
|
||||
def load_answer(tokenizer, cfg):
|
||||
return CreativeAnsweringPromptTokenizingStrategy(
|
||||
CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
CreativeAnswerPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_critique(tokenizer, cfg):
|
||||
return CreativeCritiquePromptTokenizingStrategy(
|
||||
CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
CreativeCritiquePrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_revise(tokenizer, cfg):
|
||||
return CreativeRevisePromptTokenizingStrategy(
|
||||
CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
CreativeRevisePrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
205
src/axolotl/prompt_strategies/llama2_chat.py
Normal file
205
src/axolotl/prompt_strategies/llama2_chat.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Prompt Strategy for finetuning Llama2 chat models
|
||||
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation.
|
||||
|
||||
This implementation is based on the Vicuna PR and the fastchat repo, see also:
|
||||
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847
|
||||
|
||||
Use dataset type: "llama2_chat" in conig.yml to use this prompt style.
|
||||
|
||||
E.g. in the config.yml:
|
||||
```
|
||||
datasets:
|
||||
- path: llama_finetune_train.jsonl
|
||||
type: llama2_chat
|
||||
```
|
||||
|
||||
The dataset itself should look like this:
|
||||
```
|
||||
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]}
|
||||
```
|
||||
in a jsonl file. The first message should be from the human, the second from gpt.
|
||||
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).
|
||||
|
||||
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generator, List, Sequence
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2ChatConversation:
|
||||
"""A class that manages prompt templates and keeps all conversation history.
|
||||
copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""
|
||||
|
||||
name: str = "llama2"
|
||||
# The system prompt
|
||||
system: str = (
|
||||
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
||||
)
|
||||
roles: Sequence[str] = ("[INST]", "[/INST]")
|
||||
messages: List[List[str]] = field(default_factory=list)
|
||||
offset: int = 0
|
||||
sep = " "
|
||||
sep2 = " </s><s>"
|
||||
stop_token_ids = [2]
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt for generation."""
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = ""
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if (i == len(self.messages) - 1) and (role == self.roles[0]):
|
||||
# last message is from user (due to length),
|
||||
# return prompt without it for training
|
||||
return ret
|
||||
if i == 0:
|
||||
ret += self.system + message.strip()
|
||||
else:
|
||||
ret += role + " " + message.strip() + seps[i % 2]
|
||||
return ret
|
||||
|
||||
def append_message(self, role: str, message: str):
|
||||
"""Append a new message."""
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sequence_len = 4096
|
||||
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
conv = next(self.prompter.build_prompt(prompt))
|
||||
conversation_str = conv.get_prompt()
|
||||
|
||||
# Tokenize conversations
|
||||
input_ids = self.tokenizer(
|
||||
conversation_str,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=self.sequence_len,
|
||||
truncation=True,
|
||||
).input_ids[0]
|
||||
target = input_ids.clone()
|
||||
|
||||
# Mask targets. Only compute loss on the assistant outputs.
|
||||
sep = conv.roles[1]
|
||||
|
||||
total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
|
||||
|
||||
turns = conversation_str.split(conv.sep2)
|
||||
cur_len = 1
|
||||
target[:cur_len] = IGNORE_TOKEN_ID
|
||||
for turn in turns:
|
||||
if turn == "":
|
||||
break
|
||||
turn_len = len(self.tokenizer(turn).input_ids)
|
||||
|
||||
parts = turn.split(sep)
|
||||
if len(parts) != 2:
|
||||
break
|
||||
parts[0] += sep
|
||||
# "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
|
||||
instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1
|
||||
|
||||
# Ignore the user instructions
|
||||
target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
|
||||
cur_len += turn_len + 2 # due to length of role token
|
||||
|
||||
target[cur_len:] = IGNORE_TOKEN_ID
|
||||
|
||||
if cur_len < self.sequence_len:
|
||||
if cur_len != total_len:
|
||||
target[:] = IGNORE_TOKEN_ID
|
||||
logging.warning(
|
||||
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
||||
f" (ignored)"
|
||||
)
|
||||
|
||||
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
|
||||
input_ids = input_ids.tolist()
|
||||
target = target.tolist()
|
||||
# this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
|
||||
# follows the original llama implementation
|
||||
for i in range(2, total_len - 2):
|
||||
if input_ids[i] == 29961:
|
||||
input_ids[i] = 518
|
||||
if target[i] == 29961:
|
||||
target[i] = 518
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": target,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
|
||||
class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for Llama2 models.
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
||||
)
|
||||
|
||||
def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
|
||||
# see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
|
||||
source = source["conversations"] # fix data structure for datasets
|
||||
|
||||
# if system prompt provided, use it
|
||||
if source[0]["from"] == "system":
|
||||
system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
|
||||
source = source[1:]
|
||||
else:
|
||||
system = self.system_prompt
|
||||
|
||||
conv = Llama2ChatConversation(system=system)
|
||||
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError
|
||||
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
if roles[source[0]["from"]] != conv.roles[0]:
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
|
||||
conv.messages = [] # pylint: disable=R0801
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
if sentence["value"]:
|
||||
conv.append_message(role, sentence["value"])
|
||||
yield conv
|
||||
|
||||
|
||||
def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
|
||||
return LLama2ChatTokenizingStrategy(
|
||||
Llama2ChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
46
src/axolotl/prompt_strategies/orcamini.py
Normal file
46
src/axolotl/prompt_strategies/orcamini.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Prompt Strategy for finetuning Orca Mini (v2) models
|
||||
see also https://huggingface.co/psmathur/orca_mini_v2_7b for more information
|
||||
|
||||
Use dataset type: orcamini in conig.yml to use this prompt style.
|
||||
|
||||
Compared to the alpaca_w_system.open_orca dataset type,
|
||||
this one specifies the system prompt with "### System:".
|
||||
|
||||
Not suited/tested for multiple-turn conversations without further adjustments.
|
||||
"""
|
||||
from typing import Generator, Union
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
|
||||
|
||||
class OrcaMiniPrompter(AlpacaPrompter):
|
||||
"""Adjusted Prompter for Orca Mini (v2) datasets"""
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_no_input_format = (
|
||||
"### System:\n{system}\n\n### User:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
instruction: str,
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
res = self.turn_no_input_format.format(system=system, instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return OpenOrcaPromptTokenizingStrategy(
|
||||
OrcaMiniPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
@@ -1,29 +1,36 @@
|
||||
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Generator
|
||||
from typing import Generator, List, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import (
|
||||
PromptTokenizingStrategy,
|
||||
parse_tokenized_to_result,
|
||||
tokenize_prompt_default,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
IGNORE_TOKEN_ID = -100
|
||||
|
||||
|
||||
class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
bot_prefix_token_ids = []
|
||||
"""
|
||||
Tokenizing strategy for Pygmalion.
|
||||
"""
|
||||
|
||||
bot_prefix_token_ids: List[int] = []
|
||||
|
||||
def __init__(self, prompter, tokenizer, *args, **kwargs):
|
||||
super().__init__(prompter, tokenizer)
|
||||
super().__init__(prompter, tokenizer, *args, **kwargs)
|
||||
res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
|
||||
self.bot_prefix_token_ids = res["input_ids"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
result = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
}
|
||||
current_len = 0
|
||||
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
||||
result, current_len = tokenize_prompt_default()
|
||||
for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
||||
role, message = part
|
||||
if role == "system":
|
||||
prefix = "<|system|>"
|
||||
@@ -59,47 +66,31 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
*copy.deepcopy(res["input_ids"])
|
||||
][len(self.bot_prefix_token_ids) :]
|
||||
else:
|
||||
logging.warning(f"unknown role in conversation: {role}")
|
||||
LOG.warning(f"unknown role in conversation: {role}")
|
||||
res = defaultdict(lambda: [])
|
||||
input_ids = res["input_ids"]
|
||||
input_len = len(input_ids)
|
||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||
result["attention_mask"][current_len : current_len + input_len] = [
|
||||
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
||||
]
|
||||
result["labels"][current_len : current_len + input_len] = labels
|
||||
current_len += input_len
|
||||
return result
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.sequence_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class PygmalionPrompter:
|
||||
"""
|
||||
Prompter for Pygmalion.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
|
||||
def build_prompt(
|
||||
self, source, *args, **kwargs # pylint: disable=unused-argument
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
for msg in source:
|
||||
yield msg["role"], msg["value"]
|
||||
|
||||
|
||||
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization strategy for asking bot to tell a joke and then explain why its funny
|
||||
"""
|
||||
|
||||
# title, text, explanation
|
||||
def get_conversation_thread(self, prompt):
|
||||
title = "" if not prompt["title"] else prompt["title"] + " "
|
||||
return [
|
||||
{"from": "human", "value": "Tell me a joke."},
|
||||
{"from": "gpt", "value": title + prompt["text"]},
|
||||
{"from": "human", "value": "Why is that joke funny?"},
|
||||
{"from": "gpt", "value": prompt["explanation"]},
|
||||
]
|
||||
67
src/axolotl/prompt_strategies/sharegpt_simple.py
Normal file
67
src/axolotl/prompt_strategies/sharegpt_simple.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_guanaco(tokenizer, cfg):
|
||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
return turns
|
||||
|
||||
|
||||
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
@@ -1,24 +1,35 @@
|
||||
"""Module containing PromptTokenizingStrategy and Prompter classes"""
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
|
||||
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
|
||||
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
|
||||
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||
|
||||
|
||||
class InvalidDataException(Exception):
|
||||
pass
|
||||
"""
|
||||
Exception raised when the data is invalid
|
||||
"""
|
||||
|
||||
|
||||
class PromptTokenizingStrategy(abc.ABC):
|
||||
"""
|
||||
Abstract class for tokenizing strategies
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter,
|
||||
@@ -35,59 +46,27 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def tokenize_prompt(self, prompt):
|
||||
pass
|
||||
|
||||
@functools.cache
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_user_token(self):
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@functools.cache
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_assistant_token(self):
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
instruction, input, response = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _build_full_prompt(self, instruction, input, response):
|
||||
return next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
response,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
@@ -111,8 +90,64 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
return result
|
||||
|
||||
|
||||
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(
|
||||
self, prompt
|
||||
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response # pylint: disable=redefined-builtin
|
||||
):
|
||||
return next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
response,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Alpaca prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
@@ -121,7 +156,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Alpaca Multiple Choice prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
"\n".join(f'- "{choice}"' for choice in prompt["choices"]),
|
||||
@@ -130,7 +169,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
|
||||
|
||||
|
||||
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Jeopardy prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
prompt["category"],
|
||||
@@ -139,7 +182,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for OpenAssistant prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["INSTRUCTION"],
|
||||
"",
|
||||
@@ -148,7 +195,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
||||
|
||||
|
||||
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for SummarizeTLDR prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"],
|
||||
"",
|
||||
@@ -157,7 +208,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
||||
|
||||
|
||||
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for GPTeacher prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
@@ -166,7 +221,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for NomicGPT4All prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["prompt"],
|
||||
"",
|
||||
@@ -175,28 +234,34 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> str:
|
||||
return prompt["text"]
|
||||
"""
|
||||
Tokenizing strategy for Completion prompts.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
instruction = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, None, None)
|
||||
full_prompt = self._build_full_prompt(prompt["text"], None, None)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _build_full_prompt(self, instruction, input, response):
|
||||
return next(iter(self.prompter.build_prompt(instruction)))
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response
|
||||
): # pylint: disable=redefined-builtin
|
||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||
|
||||
|
||||
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Reflection prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
(
|
||||
instruction,
|
||||
input,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
output,
|
||||
reflection,
|
||||
corrected,
|
||||
@@ -223,7 +288,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _build_full_prompt(self, instruction, input, output, reflection, corrected):
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, output, reflection, corrected
|
||||
): # pylint: disable=redefined-builtin
|
||||
return next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
@@ -236,7 +303,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
)
|
||||
)
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True):
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
@@ -257,7 +324,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
|
||||
class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
|
||||
"""
|
||||
Tokenizing strategy for Alpaca Reflection prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
@@ -268,20 +339,19 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
|
||||
|
||||
|
||||
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
result = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
}
|
||||
current_len = 0
|
||||
result, current_len = tokenize_prompt_default()
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
try:
|
||||
for i, part in enumerate(
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
@@ -289,7 +359,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
part = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=False, strip_bos_token=True
|
||||
part.strip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
@@ -300,32 +372,39 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
part = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistent response, should end with an eos token
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=True, strip_bos_token=True
|
||||
part.strip(),
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [assistant_token, *res["input_ids"]]
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
elif part[0] == "SYSTEM:":
|
||||
part = part[1] # Ignore the system role from preamble
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
logging.warning("unhandled role: " + part[0])
|
||||
else:
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
input_ids = res["input_ids"]
|
||||
input_len = len(input_ids)
|
||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||
result["attention_mask"][current_len : current_len + input_len] = [
|
||||
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
||||
]
|
||||
result["labels"][current_len : current_len + input_len] = labels
|
||||
current_len += input_len
|
||||
LOG.warning(f"unhandled role: {part[0]}")
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as e:
|
||||
raise InvalidDataException(str(e))
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
@@ -349,3 +428,40 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
Returns the default values for the tokenize prompt function
|
||||
"""
|
||||
|
||||
result: Dict[str, List[int]] = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
}
|
||||
current_len = 0
|
||||
return result, current_len
|
||||
|
||||
|
||||
def parse_tokenized_to_result(
|
||||
result: Dict[str, List[int]],
|
||||
current_len: int,
|
||||
res: Dict[str, List[int]],
|
||||
labels: List[int],
|
||||
pad_token_id: Union[int, None] = None,
|
||||
) -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
|
||||
"""
|
||||
|
||||
input_ids = res["input_ids"]
|
||||
input_len = len(input_ids)
|
||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||
result["attention_mask"][current_len : current_len + input_len] = [
|
||||
1 if x != pad_token_id else 0 for x in input_ids
|
||||
]
|
||||
result["labels"][current_len : current_len + input_len] = labels
|
||||
current_len += input_len
|
||||
|
||||
return result, current_len
|
||||
|
||||
@@ -1,110 +1,167 @@
|
||||
import copy
|
||||
"""Module containing prompters"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from enum import auto, Enum
|
||||
from typing import List, Tuple, Any, Union, Generator
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
|
||||
|
||||
class PromptStyle(Enum):
|
||||
instruct = "instruct"
|
||||
chat = "chat"
|
||||
"""
|
||||
Enum for prompt styles
|
||||
"""
|
||||
|
||||
INSTRUCT = "instruct"
|
||||
CHAT = "chat"
|
||||
CHATML = "chatml"
|
||||
|
||||
|
||||
class AlpacaPrompter:
|
||||
"""
|
||||
Base class for alpaca prompters
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
prompt_style = None
|
||||
system_format: str
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
prompt_style: Optional[PromptStyle] = None
|
||||
|
||||
def __init__(self, prompt_style=PromptStyle.instruct.value):
|
||||
self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
|
||||
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
|
||||
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
|
||||
self.match_prompt_style()
|
||||
|
||||
def match_prompt_style(self):
|
||||
if self.prompt_style == PromptStyle.instruct.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
# pylint: disable=duplicate-code
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
self.turn_no_input_format = (
|
||||
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
self.prompt_no_input = (
|
||||
self.system_no_input_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
self.system_format = "### System:\n{system}\n\n"
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
self.system_format = "SYSTEM: {system}\n"
|
||||
if self.prompt_style == PromptStyle.CHATML.value:
|
||||
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.turn_no_input_format = (
|
||||
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
self.response_split = "### Response:"
|
||||
if self.prompt_style == PromptStyle.chat.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
)
|
||||
self.prompt_no_input = (
|
||||
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = self.prompt_input.format(instruction=instruction, input=input)
|
||||
res = self.system_prompt + self.turn_format.format(
|
||||
instruction=instruction, input=input
|
||||
)
|
||||
else:
|
||||
res = self.prompt_no_input.format(instruction=instruction)
|
||||
res = self.system_no_input_prompt + self.turn_no_input_format.format(
|
||||
instruction=instruction
|
||||
)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.split(self.response_split)[1].strip()
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for alpaca no system prompt
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
|
||||
|
||||
class JeopardyPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for Jeopardy
|
||||
"""
|
||||
|
||||
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
|
||||
|
||||
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for multiple choice explain
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"Choose the answer that best answers the question. Explain your reasoning."
|
||||
"Choose the answer that best answers the question. Explain your reasoning.\n"
|
||||
)
|
||||
system_no_input_prompt = (
|
||||
"Choose the answer that best answers the question. Explain your reasoning.\n"
|
||||
)
|
||||
|
||||
|
||||
class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
||||
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
|
||||
"""
|
||||
Prompter for multiple choice concise
|
||||
"""
|
||||
|
||||
system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
|
||||
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
|
||||
|
||||
class SummarizeTLDRPrompter(AlpacaPrompter):
|
||||
prompt_no_input = (
|
||||
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||
)
|
||||
"""
|
||||
Prompter for summarize TLDR
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||
|
||||
|
||||
class CompletionPrompter:
|
||||
"""
|
||||
Prompter for completion
|
||||
"""
|
||||
|
||||
def build_prompt(
|
||||
self, instruction: str, input=None, output=None
|
||||
self,
|
||||
instruction: str,
|
||||
input=None, # pylint: disable=redefined-builtin, unused-argument
|
||||
output=None, # pylint: disable=unused-argument
|
||||
) -> Generator[str, None, None]:
|
||||
yield instruction
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.strip()
|
||||
|
||||
|
||||
class GPTeacherPrompter(AlpacaPrompter):
|
||||
...
|
||||
"""
|
||||
Prompter for GPTeacher
|
||||
"""
|
||||
|
||||
|
||||
class NomicGPT4AllPrompter(AlpacaPrompter):
|
||||
...
|
||||
"""
|
||||
Prompter for NomicGPT4All
|
||||
"""
|
||||
|
||||
|
||||
class ReflectAlpacaPrompter:
|
||||
"""
|
||||
Prompter for ReflectAlpaca
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
||||
|
||||
@@ -120,7 +177,7 @@ class ReflectAlpacaPrompter:
|
||||
self.match_prompt_style()
|
||||
|
||||
def match_prompt_style(self):
|
||||
if self.prompt_style == PromptStyle.instruct.value:
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
@@ -131,7 +188,7 @@ class ReflectAlpacaPrompter:
|
||||
)
|
||||
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
||||
self.response_split = "### Final Response:"
|
||||
if self.prompt_style == PromptStyle.chat.value:
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
)
|
||||
@@ -146,7 +203,7 @@ class ReflectAlpacaPrompter:
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
@@ -159,14 +216,13 @@ class ReflectAlpacaPrompter:
|
||||
res = self.prompt_no_input.format(instruction=instruction)
|
||||
if output and reflection and corrected:
|
||||
label = self.agent_label.format(
|
||||
output=output, reflection=reflection, corrected=corrected
|
||||
output=output,
|
||||
reflection=reflection,
|
||||
corrected=corrected,
|
||||
)
|
||||
res = f"{res}{label}"
|
||||
yield res
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.split(self.response_split)[1].strip()
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
@@ -187,18 +243,18 @@ class Conversation:
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: str = None
|
||||
sep2: Optional[str] = None
|
||||
|
||||
def get_prompt(self) -> Generator[str, None, None]:
|
||||
seps = [self.sep, self.sep2]
|
||||
preamble = self.system + seps[0]
|
||||
yield preamble
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
||||
# seps = [self.sep, self.sep2]
|
||||
preamble = self.system + self.sep
|
||||
yield ("SYSTEM:", preamble)
|
||||
for _, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield (role + ":", " " + message)
|
||||
else:
|
||||
logging.warning("role with empty message: " + role)
|
||||
yield (role + ":",)
|
||||
LOG.warning(f"role with empty message: {role}")
|
||||
yield (role + ":", "")
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
@@ -215,32 +271,40 @@ class Conversation:
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
conv_vicuna_v1_1 = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPrompter:
|
||||
def __init__(self, prompt_style=None):
|
||||
if prompt_style != PromptStyle.chat.value:
|
||||
raise Exception(
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
||||
if prompt_style != PromptStyle.CHAT.value:
|
||||
raise ValueError(
|
||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
||||
)
|
||||
system: str = (
|
||||
system_prompt
|
||||
if system_prompt
|
||||
else (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
)
|
||||
)
|
||||
self._conversation = Conversation(
|
||||
system=system,
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
|
||||
# def match_prompt_style(self):
|
||||
# if self.prompt_style == PromptStyle.chat.value:
|
||||
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||
# self.response_split = "ASSISTANT:"
|
||||
|
||||
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
# ignore the system prompt if provided
|
||||
if source[0]["from"] == "system":
|
||||
source.pop(0)
|
||||
@@ -250,7 +314,7 @@ class ShareGPTPrompter:
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError
|
||||
|
||||
conv = conv_vicuna_v1_1.copy()
|
||||
conv = self._conversation.copy()
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
try:
|
||||
@@ -261,14 +325,14 @@ class ShareGPTPrompter:
|
||||
):
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as e:
|
||||
except IndexError as err:
|
||||
# sometimes there is a bing or system chat
|
||||
raise e
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_prompt():
|
||||
|
||||
23
src/axolotl/utils/bench.py
Normal file
23
src/axolotl/utils/bench.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
|
||||
|
||||
def gpu_memory_usage(device):
|
||||
if isinstance(device, torch.device):
|
||||
device = device.index
|
||||
if isinstance(device, str) and device.startswith("cuda:"):
|
||||
device = int(device[5:])
|
||||
|
||||
# NB torch.cuda.memory_usage returns zero so we use lower level api
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used / 1024.0**3
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
log.info(
|
||||
f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
|
||||
)
|
||||
@@ -1,16 +1,25 @@
|
||||
"""Callbacks for Trainer class"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import (
|
||||
Seq2SeqTrainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
TrainerState,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the PEFT adapter"""
|
||||
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
@@ -19,10 +28,71 @@ class SavePeftModelCallback(TrainerCallback):
|
||||
**kwargs,
|
||||
):
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||
kwargs["model"].save_pretrained(peft_model_path)
|
||||
kwargs["model"].save_pretrained(
|
||||
peft_model_path, save_safetensors=args.save_safetensors
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
class SaveBetterTransformerModelCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the BetterTransformer wrapped model"""
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Save
|
||||
if (
|
||||
args.save_strategy == IntervalStrategy.STEPS
|
||||
and args.save_steps > 0
|
||||
and state.global_step % args.save_steps == 0
|
||||
):
|
||||
control.should_save = True
|
||||
|
||||
if control.should_save:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
model = BetterTransformer.reverse(kwargs["model"])
|
||||
model.save_pretrained(checkpoint_folder)
|
||||
# FIXME - need to cleanup old checkpoints
|
||||
|
||||
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
|
||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||
control.should_save = False
|
||||
return control
|
||||
|
||||
|
||||
class PrintGPUStatsCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||
"""Callback to print GPU utilization"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if not self.logged:
|
||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||
self.logged = True
|
||||
return control
|
||||
|
||||
121
src/axolotl/utils/collators.py
Normal file
121
src/axolotl/utils/collators.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
||||
The tokenizer used for encoding the data.
|
||||
model ([`PreTrainedModel`]):
|
||||
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
|
||||
prepare the *decoder_input_ids*
|
||||
|
||||
This is useful when using *label_smoothing* to avoid calculating loss twice.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (`int`, *optional*):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
label_pad_token_id (`int`, *optional*, defaults to -100):
|
||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||
return_tensors (`str`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
model: Optional[Any] = None
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
labels = None
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
for feature_name, pad_token_id in [
|
||||
("labels", self.label_pad_token_id),
|
||||
("position_ids", self.position_pad_token_id),
|
||||
]:
|
||||
feat = (
|
||||
[feature[feature_name] for feature in features]
|
||||
if feature_name in features[0].keys()
|
||||
else None
|
||||
)
|
||||
labels = feat if feat and feature_name == "labels" else labels
|
||||
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
||||
# same length to return tensors.
|
||||
if feat is not None:
|
||||
max_feature_length = max(len(l) for l in feat) # noqa: E741
|
||||
if self.pad_to_multiple_of is not None:
|
||||
max_feature_length = (
|
||||
(max_feature_length + self.pad_to_multiple_of - 1)
|
||||
// self.pad_to_multiple_of
|
||||
* self.pad_to_multiple_of
|
||||
)
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
for feature in features:
|
||||
remainder = [pad_token_id] * (
|
||||
max_feature_length - len(feature[feature_name])
|
||||
)
|
||||
if isinstance(feature[feature_name], list):
|
||||
feature[feature_name] = (
|
||||
feature[feature_name] + remainder
|
||||
if padding_side == "right"
|
||||
else remainder + feature[feature_name]
|
||||
)
|
||||
elif padding_side == "right":
|
||||
feature[feature_name] = np.concatenate(
|
||||
[feature[feature_name], remainder]
|
||||
).astype(np.int64)
|
||||
else:
|
||||
feature[feature_name] = np.concatenate(
|
||||
[remainder, feature[feature_name]]
|
||||
).astype(np.int64)
|
||||
|
||||
features = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
|
||||
# prepare decoder_input_ids
|
||||
if (
|
||||
labels is not None
|
||||
and self.model is not None
|
||||
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
||||
):
|
||||
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
|
||||
labels=features["labels"]
|
||||
)
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
@@ -1,43 +1,49 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import (
|
||||
load_from_disk,
|
||||
load_dataset,
|
||||
IterableDataset,
|
||||
Dataset,
|
||||
concatenate_datasets,
|
||||
DatasetDict,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
CompletionPromptTokenizingStrategy,
|
||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
CompletionPromptTokenizingStrategy,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
SummarizeTLDRPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
CompletionPrompter,
|
||||
GPTeacherPrompter,
|
||||
JeopardyPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
ReflectAlpacaPrompter,
|
||||
ShareGPTPrompter,
|
||||
JeopardyPrompter,
|
||||
CompletionPrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
)
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def load_tokenized_prepared_datasets(
|
||||
@@ -45,11 +51,13 @@ def load_tokenized_prepared_datasets(
|
||||
) -> DatasetDict:
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
ds_hash = str(
|
||||
md5(
|
||||
md5( # nosec
|
||||
(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|".join(
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
)
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
).encode("utf-8")
|
||||
@@ -65,59 +73,96 @@ def load_tokenized_prepared_datasets(
|
||||
try:
|
||||
if cfg.push_dataset_to_hub:
|
||||
dataset = load_dataset(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
pass
|
||||
|
||||
if dataset:
|
||||
...
|
||||
elif any(prepared_ds_path.glob("*")):
|
||||
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
dataset = load_from_disk(str(prepared_ds_path))
|
||||
logging.info("Prepared dataset loaded from disk...")
|
||||
LOG.info("Prepared dataset loaded from disk...")
|
||||
else:
|
||||
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||
logging.info("Loading raw datasets...")
|
||||
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||
LOG.info("Loading raw datasets...")
|
||||
|
||||
if cfg.seed:
|
||||
seed = cfg.seed
|
||||
else:
|
||||
LOG.info("No seed provided, using default seed of 42")
|
||||
seed = 42
|
||||
|
||||
datasets = []
|
||||
# pylint: disable=invalid-name
|
||||
for d in cfg.datasets:
|
||||
ds: Union[Dataset, DatasetDict] = None
|
||||
ds_from_hub = False
|
||||
try:
|
||||
load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
|
||||
load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
streaming=True,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
if Path(d.path).exists():
|
||||
ds: Dataset = load_dataset(
|
||||
"json", data_files=d.path, streaming=False, split=None
|
||||
)
|
||||
elif ds_from_hub:
|
||||
if d.data_files:
|
||||
ds: Dataset = load_dataset(
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
streaming=False,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
use_auth_token=use_auth_token,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
name=d.name,
|
||||
data_files=d.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=use_auth_token)
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
streaming=False,
|
||||
data_files=d.data_files,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
fp = hf_hub_download(
|
||||
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
||||
repo_id=d.path,
|
||||
repo_type="dataset",
|
||||
filename=d.data_files,
|
||||
)
|
||||
ds = load_dataset(
|
||||
"json", name=d.name, data_files=fp, streaming=False, split=None
|
||||
)
|
||||
ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
|
||||
if not ds:
|
||||
raise Exception("unhandled dataset load")
|
||||
raise ValueError("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if d.shards:
|
||||
if "train" in ds:
|
||||
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||
num_shards=d.shards, index=0
|
||||
)
|
||||
else:
|
||||
ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
|
||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||
d_type = d.type
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
@@ -218,20 +263,24 @@ def load_tokenized_prepared_datasets(
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||
logging.info("tokenizing, merging, and shuffling master dataset")
|
||||
suffix = ""
|
||||
if ":load_" in d.type:
|
||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
)
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
|
||||
samples = []
|
||||
for d in datasets:
|
||||
samples = samples + [i for i in d]
|
||||
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
||||
if len(datasets) > 1:
|
||||
LOG.info("shuffle merged datasets")
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
if cfg.local_rank == 0:
|
||||
logging.info(
|
||||
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
||||
)
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
if cfg.push_dataset_to_hub:
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
)
|
||||
dataset.push_to_hub(
|
||||
@@ -242,8 +291,10 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
||||
) -> (Dataset, Dataset):
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
@@ -256,13 +307,15 @@ def load_prepare_datasets(
|
||||
# see if we can go ahead and load the stacked dataset
|
||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||
ds_hash = str(
|
||||
md5(
|
||||
md5( # nosec
|
||||
(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ str(max_packed_sequence_len)
|
||||
+ seed
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|".join(
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
)
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
).encode("utf-8")
|
||||
@@ -278,26 +331,27 @@ def load_prepare_datasets(
|
||||
use_auth_token = cfg.hf_use_auth_token
|
||||
try:
|
||||
if cfg.push_dataset_to_hub:
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
)
|
||||
dataset = load_dataset(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
pass
|
||||
|
||||
if dataset:
|
||||
...
|
||||
elif any(prepared_ds_path.glob("*")):
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
||||
)
|
||||
dataset = load_from_disk(str(prepared_ds_path))
|
||||
logging.info("Prepared packed dataset loaded from disk...")
|
||||
LOG.info("Prepared packed dataset loaded from disk...")
|
||||
if cfg.push_dataset_to_hub:
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
)
|
||||
dataset.push_to_hub(
|
||||
@@ -316,17 +370,16 @@ def load_prepare_datasets(
|
||||
[dataset],
|
||||
seq_length=max_packed_sequence_len,
|
||||
)
|
||||
logging.info(
|
||||
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
||||
)
|
||||
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
||||
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
|
||||
dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
|
||||
# filter out bad data
|
||||
# TODO convert to dataset.filter(...)
|
||||
dataset = Dataset.from_list(
|
||||
[
|
||||
d
|
||||
for d in dataset
|
||||
if len(d["input_ids"]) < cfg.sequence_len
|
||||
if len(d["input_ids"]) <= cfg.sequence_len
|
||||
and len(d["input_ids"]) > 0
|
||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||
and len(d["input_ids"]) == len(d["labels"])
|
||||
@@ -334,16 +387,17 @@ def load_prepare_datasets(
|
||||
)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
||||
)
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
if cfg.push_dataset_to_hub:
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||
)
|
||||
dataset.push_to_hub(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
private=True,
|
||||
)
|
||||
else:
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
@@ -351,15 +405,179 @@ def load_prepare_datasets(
|
||||
)
|
||||
|
||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||
logging.info(
|
||||
LOG.info(
|
||||
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
||||
)
|
||||
dataset = dataset.shard(
|
||||
num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
|
||||
num_shards=cfg.dataset_shard_num,
|
||||
index=cfg.dataset_shard_idx,
|
||||
)
|
||||
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
if cfg.val_set_size:
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
to_hash_test = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
train_fingerprint = hashlib.md5(
|
||||
to_hash_train.encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
test_fingerprint = hashlib.md5(
|
||||
to_hash_test.encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
|
||||
if is_main_process():
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
barrier()
|
||||
if not is_main_process():
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
barrier()
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
else:
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def encode_pretraining(tokenizer, max_tokens, examples):
|
||||
res = tokenizer(
|
||||
examples["text"],
|
||||
truncation=True,
|
||||
max_length=max_tokens - 2,
|
||||
add_special_tokens=True,
|
||||
)
|
||||
# Convert to PyTorch tensors
|
||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||
new_input_ids = []
|
||||
new_attention_mask = []
|
||||
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
||||
for i, _ in enumerate(input_ids):
|
||||
input_ids[i] = torch.cat(
|
||||
(
|
||||
input_ids[i],
|
||||
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
||||
|
||||
# Concatenate tokens so that their lengths are less than max_tokens
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
for ids, mask in zip(input_ids, attention_mask):
|
||||
if buffer_input_ids.numel() == max_tokens:
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
else:
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
|
||||
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
||||
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
|
||||
ret = {
|
||||
"input_ids": [seq.tolist() for seq in new_input_ids],
|
||||
"labels": [seq.tolist() for seq in new_input_ids],
|
||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
||||
}
|
||||
|
||||
LOG.debug(len(ret["input_ids"]))
|
||||
return ret
|
||||
|
||||
|
||||
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
dataset = load_dataset(path, streaming=True, split="train")
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||
# TODO dynamically figure out which columns/features to remove
|
||||
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
|
||||
return dataset
|
||||
|
||||
288
src/axolotl/utils/dataloader.py
Normal file
288
src/axolotl/utils/dataloader.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# pylint: skip-file
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result, len(a)
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
"""
|
||||
:param lengths: array of lengths of each sample
|
||||
:param lengths_cumsum: cumulative sum of consecutive lengths
|
||||
:param rank: rank for this process
|
||||
:param c: length of tokens per batch
|
||||
:param n: number of ranks
|
||||
:return:
|
||||
"""
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
result_totseqs = []
|
||||
|
||||
while True:
|
||||
# binary search [left, right)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# use length left
|
||||
batch, tot_seqs = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
# add total seqs for all ranks
|
||||
result_totseqs.append(tot_seqs)
|
||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||
return result, result_totseqs, s, len(result) * c * n
|
||||
|
||||
|
||||
def chunk(iterable, n):
|
||||
"""
|
||||
Chunk data into tuples of length n
|
||||
"""
|
||||
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||
if n < 1:
|
||||
raise ValueError("n must be at least one")
|
||||
it = iter(iterable)
|
||||
while batch := tuple(itertools.islice(it, n)):
|
||||
yield batch
|
||||
|
||||
|
||||
def hash_indices(lst: List[int]) -> str:
|
||||
# Convert the list of integers to a string representation
|
||||
concatenated = ",".join(map(str, lst))
|
||||
|
||||
# Generate the hash
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(concatenated.encode())
|
||||
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
class MultipackDistributedDataloader:
|
||||
"""Unpadded data loading using Multipack.
|
||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Any,
|
||||
collate_fn: Callable,
|
||||
seq_max_length: int = 2048,
|
||||
batch_size: int = 1,
|
||||
sampler: Union[Sampler, DistributedSampler] = None,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
self.lengths = (
|
||||
dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
)
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||
assert batch_size >= sample_packing_seq_len_multiplier
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.device_count = device_count
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
LOG.info("generating packed batches")
|
||||
if self.sampler:
|
||||
indices = [idx for idx in self.sampler]
|
||||
else:
|
||||
indices = range(0, len(self.dataset))
|
||||
|
||||
LOG.info(hash_indices(indices))
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, totseqs, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches, totseqs
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batches in chunk(
|
||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||
):
|
||||
chunked_data = []
|
||||
attn_mask_cum_idx = 0
|
||||
for batch in batches:
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
attn_mask_cum_idx += len(batched_data)
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature])
|
||||
for item in batched_data
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
chunked_data.append(concatenated)
|
||||
yield self.collate_fn(chunked_data)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
|
||||
def _len_est(self):
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // self.device_count
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return (
|
||||
math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.seq_max_length
|
||||
// self.batch_size
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||
# the same share of total tokens
|
||||
# if not self.eff_total_used:
|
||||
# batches, _ = self.generate_batches(set_stats=True)
|
||||
# LOG.info(
|
||||
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
# f"actual packing efficiency: {self.efficiency()}"
|
||||
# )
|
||||
return max(1, self._len_est())
|
||||
|
||||
def len_w_stats(self):
|
||||
if not self.eff_total_used:
|
||||
batches, _ = self.generate_batches(set_stats=True)
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"actual packing efficiency: {self.efficiency()}"
|
||||
)
|
||||
return max(1, self._len_est())
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Module containing the DictDefault class"""
|
||||
|
||||
from addict import Dict
|
||||
|
||||
|
||||
|
||||
41
src/axolotl/utils/distributed.py
Normal file
41
src/axolotl/utils/distributed.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
utility helpers for distributed checks
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerate = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def load_accelerate():
|
||||
global accelerate # pylint: disable=global-statement
|
||||
accelerate = Accelerator()
|
||||
|
||||
|
||||
def is_distributed():
|
||||
"""
|
||||
Check if distributed training is initialized.
|
||||
"""
|
||||
global accelerate # pylint: disable=global-statement
|
||||
if not accelerate:
|
||||
accelerate = Accelerator()
|
||||
return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def barrier():
|
||||
"""
|
||||
Acts as a barrier to wait for all processes. This ensures that all processes
|
||||
reach the barrier before proceeding further.
|
||||
"""
|
||||
if is_distributed():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
"""
|
||||
Check if the current process is the main process.
|
||||
If not in distributed mode, always return True.
|
||||
"""
|
||||
if not is_distributed():
|
||||
return True
|
||||
return dist.get_rank() == 0
|
||||
@@ -1,60 +1,73 @@
|
||||
"""Module for models and model loading"""
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import (
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
AutoConfig,
|
||||
BitsAndBytesConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers import (
|
||||
LlamaForCausalLM,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
except:
|
||||
logging.warning(
|
||||
"This version of transformers does not support Llama. Consider upgrading."
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from peft import PeftModel, PeftConfig
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from transformers import PreTrainedTokenizer
|
||||
from peft import PeftConfig # noqa: F401
|
||||
|
||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if cfg.tokenizer_legacy is not None:
|
||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
||||
if tokenizer.__class__.__name__ in [
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
]:
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
@@ -62,8 +75,8 @@ def load_tokenizer(
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
if cfg.special_tokens:
|
||||
for k, v in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens({k: v})
|
||||
for k, val in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens({k: val})
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(list(cfg.tokens))
|
||||
|
||||
@@ -71,39 +84,77 @@ def load_tokenizer(
|
||||
|
||||
|
||||
def load_model(
|
||||
base_model,
|
||||
base_model_config,
|
||||
model_type,
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter="lora",
|
||||
inference=False,
|
||||
):
|
||||
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
|
||||
cfg, tokenizer
|
||||
): # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
adapter = cfg.adapter
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
is_llama_derived_model = "llama" in base_model or (
|
||||
cfg.model_type and "llama" in cfg.model_type.lower()
|
||||
cfg.is_llama_derived_model = (
|
||||
"llama" in base_model
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or cfg.is_llama_derived_model
|
||||
)
|
||||
|
||||
if is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and inference is False:
|
||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
logging.info("patching with flash attention")
|
||||
LOG.info("patching with flash attention")
|
||||
replace_llama_attn_with_flash_attn()
|
||||
elif is_llama_derived_model and cfg.xformers_attention:
|
||||
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
|
||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
)
|
||||
|
||||
logging.info("patching with xformers attention")
|
||||
LOG.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_sdp_attention,
|
||||
)
|
||||
|
||||
if cfg.bf16:
|
||||
LOG.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
MEM_TOKEN,
|
||||
patch_llama_with_landmark_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching with landmark attention")
|
||||
patch_llama_with_landmark_attn()
|
||||
|
||||
# Note: This might overwrite previous additional_special_tokens
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
)
|
||||
|
||||
LOG.info("patching with xpos rope")
|
||||
replace_llama_rope_with_xpos_rope()
|
||||
|
||||
if cfg.is_llama_derived_model and (
|
||||
cfg.max_packed_sequence_len or cfg.sample_packing
|
||||
):
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
LOG.info("patching _expand_mask")
|
||||
hijack_expand_mask()
|
||||
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16:
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
@@ -114,12 +165,25 @@ def load_model(
|
||||
)
|
||||
|
||||
replace_peft_model_with_int4_lora_model()
|
||||
from peft import prepare_model_for_int8_training
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise e
|
||||
except Exception as err:
|
||||
LOG.exception(err)
|
||||
raise err
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
):
|
||||
try:
|
||||
from peft import prepare_model_for_kbit_training
|
||||
except ImportError:
|
||||
# For backward compatibility
|
||||
from peft import (
|
||||
prepare_model_for_int8_training as prepare_model_for_kbit_training,
|
||||
)
|
||||
|
||||
model_kwargs = {}
|
||||
if cfg.model_revision:
|
||||
model_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
@@ -130,7 +194,7 @@ def load_model(
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
try:
|
||||
if cfg.gptq and is_llama_derived_model:
|
||||
if cfg.gptq and cfg.is_llama_derived_model:
|
||||
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@@ -151,11 +215,11 @@ def load_model(
|
||||
if len(files) > 0:
|
||||
model_path = str(files[0])
|
||||
else:
|
||||
logging.warning(
|
||||
LOG.warning(
|
||||
"unable to find a cached model file, this will likely fail..."
|
||||
)
|
||||
model_path = str(cache_model_path)
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
model_path = cfg.base_model
|
||||
model, _ = load_llama_model_4bit_low_ram(
|
||||
base_model_config if base_model_config else base_model,
|
||||
@@ -168,13 +232,18 @@ def load_model(
|
||||
else True,
|
||||
)
|
||||
load_in_8bit = False
|
||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config = LlamaConfig.from_pretrained(
|
||||
base_model_config, rope_scaling=cfg.rope_scaling
|
||||
)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
|
||||
**model_kwargs,
|
||||
)
|
||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||
@@ -203,55 +272,97 @@ def load_model(
|
||||
# device=cfg.device,
|
||||
# )
|
||||
# model.train() # sets to train instead of eval mode
|
||||
elif model_type:
|
||||
elif model_type and not cfg.trust_remote_code:
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(config, "max_seq_len")
|
||||
and config.max_seq_len
|
||||
and cfg.sequence_len > config.max_seq_len
|
||||
):
|
||||
config.max_seq_len = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(config, "max_sequence_length")
|
||||
and config.max_sequence_length
|
||||
and cfg.sequence_len > config.max_sequence_length
|
||||
):
|
||||
config.max_sequence_length = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
except Exception as err: # pylint: disable=broad-exception-caught
|
||||
LOG.error(
|
||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||
)
|
||||
logging.exception(e)
|
||||
LOG.exception(err)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
else len(tokenizer)
|
||||
)
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
|
||||
if (
|
||||
((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
|
||||
and not cfg.gptq
|
||||
and (load_in_8bit or cfg.load_in_4bit)
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings
|
||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||
):
|
||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||
model = prepare_model_for_int8_training(model)
|
||||
LOG.warning(
|
||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if cfg.flash_attention and cfg.is_llama_derived_model:
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
module.to(torch_dtype)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch_dtype)
|
||||
|
||||
model, lora_config = load_adapter(model, cfg, adapter)
|
||||
|
||||
@@ -260,15 +371,18 @@ def load_model(
|
||||
|
||||
if cfg.gptq:
|
||||
# Scales to half
|
||||
logging.info("Fitting 4bit scales and zeros to half")
|
||||
for n, m in model.named_modules():
|
||||
if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
|
||||
type(m)
|
||||
LOG.info("Fitting 4bit scales and zeros to half")
|
||||
for _, module in model.named_modules():
|
||||
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
|
||||
type(module)
|
||||
):
|
||||
if hasattr(m, "is_v1_model") and m.is_v1_model:
|
||||
m.zeros = m.zeros.half()
|
||||
m.scales = m.scales.half()
|
||||
m.bias = m.bias.half()
|
||||
if hasattr(module, "is_v1_model") and module.is_v1_model:
|
||||
module.zeros = module.zeros.half()
|
||||
module.scales = module.scales.half()
|
||||
module.bias = module.bias.half()
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
@@ -278,17 +392,20 @@ def load_model(
|
||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||
# so let's only set it for the 4bit, see
|
||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||
setattr(model, 'is_parallelizable', True)
|
||||
setattr(model, 'model_parallel', True)
|
||||
setattr(model, "is_parallelizable", True)
|
||||
setattr(model, "model_parallel", True)
|
||||
|
||||
requires_grad = []
|
||||
for name, param in model.named_parameters(recurse=True):
|
||||
if param.requires_grad:
|
||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||
if len(requires_grad) == 0:
|
||||
logging.warning("there are no parameters that require gradient updates")
|
||||
LOG.warning("there are no parameters that require gradient updates")
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, lora_config
|
||||
|
||||
@@ -298,6 +415,8 @@ def load_adapter(model, cfg, adapter):
|
||||
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg)
|
||||
if adapter == "llama-adapter":
|
||||
@@ -308,11 +427,7 @@ def load_adapter(model, cfg, adapter):
|
||||
|
||||
def load_llama_adapter(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
from peft import (
|
||||
AdaptionPromptConfig,
|
||||
get_peft_model,
|
||||
PeftModel,
|
||||
)
|
||||
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
|
||||
|
||||
peft_config = AdaptionPromptConfig(
|
||||
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
||||
@@ -321,11 +436,10 @@ def load_llama_adapter(model, cfg):
|
||||
)
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
logging.info("Loading pretained LORA")
|
||||
LOG.info("Loading pretained LORA")
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
device_map=cfg.device_map,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
else:
|
||||
@@ -357,11 +471,7 @@ def find_all_linear_names(bits, model):
|
||||
def load_lora(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
PeftModel,
|
||||
)
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||
|
||||
@@ -373,7 +483,7 @@ def load_lora(model, cfg):
|
||||
bits = 8
|
||||
|
||||
linear_names = find_all_linear_names(bits, model)
|
||||
logging.info(f"found linear modules: {repr(linear_names)}")
|
||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||
|
||||
lora_config = LoraConfig(
|
||||
@@ -391,8 +501,7 @@ def load_lora(model, cfg):
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
device_map=cfg.device_map,
|
||||
# torch_dtype=torch.float16,
|
||||
is_trainable=not cfg.inference,
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
"""Module for custom LRScheduler class"""
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
|
||||
class InterpolatingLogScheduler(LRScheduler):
|
||||
"""
|
||||
A scheduler that interpolates learning rates in a logarithmic fashion
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
|
||||
"""A scheduler that interpolates learning rates in a logarithmic fashion
|
||||
|
||||
@@ -19,7 +28,9 @@ class InterpolatingLogScheduler(LRScheduler):
|
||||
self.num_steps = num_steps
|
||||
self.min_lr = min_lr
|
||||
self.max_lr = max_lr
|
||||
self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
|
||||
self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name
|
||||
1 / (num_steps - 1)
|
||||
)
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
@@ -34,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
|
||||
lrs = [self.max_lr for base_lr in self.base_lrs]
|
||||
|
||||
return lrs
|
||||
|
||||
|
||||
def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
|
||||
current_step: int,
|
||||
*,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
num_cycles: float
|
||||
):
|
||||
if current_step < num_warmup_steps:
|
||||
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(
|
||||
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
||||
)
|
||||
|
||||
|
||||
def get_cosine_schedule_with_quadratic_warmup(
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: int,
|
||||
num_training_steps: int,
|
||||
num_cycles: float = 0.5,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
"""
|
||||
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
||||
initial lr set in the optimizer.
|
||||
|
||||
Args:
|
||||
optimizer ([`~torch.optim.Optimizer`]):
|
||||
The optimizer for which to schedule the learning rate.
|
||||
num_warmup_steps (`int`):
|
||||
The number of steps for the warmup phase.
|
||||
num_training_steps (`int`):
|
||||
The total number of training steps.
|
||||
num_cycles (`float`, *optional*, defaults to 0.5):
|
||||
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
||||
following a half-cosine).
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
|
||||
Return:
|
||||
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||
"""
|
||||
|
||||
lr_lambda = partial(
|
||||
_get_cosine_schedule_with_quadratic_warmup_lr_lambda,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
num_cycles=num_cycles,
|
||||
)
|
||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
from termcolor import colored
|
||||
"""Module for tokenization utilities"""
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_dataset_labels(dataset, tokenizer):
|
||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||
@@ -17,7 +23,7 @@ def check_example_labels(example, tokenizer):
|
||||
# You can compare the input_ids and labels element-wise
|
||||
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
||||
colored_tokens = []
|
||||
for i, (input_id, label_id, mask) in enumerate(
|
||||
for _, (input_id, label_id, mask) in enumerate(
|
||||
zip(input_ids, labels, attention_mask)
|
||||
):
|
||||
decoded_input_token = tokenizer.decode(input_id)
|
||||
@@ -28,5 +34,7 @@ def check_example_labels(example, tokenizer):
|
||||
)
|
||||
colored_tokens.append(colored_token)
|
||||
|
||||
logging.info(" ".join(colored_tokens))
|
||||
logging.info("\n\n\n")
|
||||
LOG.info(" ".join(colored_tokens))
|
||||
LOG.info("\n\n\n")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
@@ -1,28 +1,242 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
import transformers
|
||||
from datasets import Dataset, set_caching_enabled
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers import EarlyStoppingCallback, Trainer
|
||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
||||
from axolotl.utils.callbacks import SavePeftModelCallback
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
PrintGPUStatsCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.schedulers import (
|
||||
InterpolatingLogScheduler,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(Trainer):
|
||||
@torch.jit.script
|
||||
def weighted_cross_entropy(
|
||||
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
|
||||
):
|
||||
# Flatten the logits, labels, and weights tensors
|
||||
logits = logits.view(
|
||||
-1, logits.size(-1)
|
||||
) # logits becomes of shape [batch_size*sequence_length, vocab_size]
|
||||
labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
|
||||
weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
|
||||
|
||||
# Compute the unweighted cross entropy loss
|
||||
losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
|
||||
|
||||
# Apply the weights to the losses and compute their sum
|
||||
return (weights * losses).sum()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def create_weighted_mask(labels: torch.Tensor):
|
||||
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
|
||||
if len(labels.shape) == 1:
|
||||
labels = labels.unsqueeze(0)
|
||||
|
||||
weights = torch.zeros_like(labels).float()
|
||||
for i in range(labels.shape[0]):
|
||||
mask = labels[i] != -100
|
||||
|
||||
# Create a tensor to track group ids
|
||||
group_ids = torch.zeros_like(labels[i]).int()
|
||||
curr_group_id = 0
|
||||
|
||||
for j in range(1, len(labels[i])):
|
||||
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
|
||||
curr_group_id += 1 # start new group
|
||||
group_ids[j] = (
|
||||
curr_group_id if mask[j] else 0
|
||||
) # assign group id if unmasked label
|
||||
|
||||
# Count only unmasked labels in each group
|
||||
group_counts = torch.bincount(group_ids[mask])
|
||||
|
||||
mask_weights = torch.zeros_like(labels[i]).float()
|
||||
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
|
||||
|
||||
weights[i] = mask_weights
|
||||
|
||||
return weights.squeeze() # squeeze the output to match the input dimension
|
||||
|
||||
|
||||
def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||
logits = (
|
||||
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||
)
|
||||
if shift_labels:
|
||||
logits = logits[..., :-1, :].contiguous()
|
||||
labels = labels[..., 1:].contiguous()
|
||||
|
||||
weights = create_weighted_mask(labels)
|
||||
return weighted_cross_entropy(logits, labels, weights)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size > 1 and self.args.sample_packing:
|
||||
return DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(
|
||||
self, eval_dataset: Optional[Dataset] = None
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
num_training_steps = num_training_steps
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
@@ -36,10 +250,121 @@ class OneCycleLRSchedulerTrainer(Trainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
def add_position_ids(sample):
|
||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||
return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048):
|
||||
return len(sample["input_ids"]) <= sequence_len
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_datasets_caching():
|
||||
try:
|
||||
set_caching_enabled(False)
|
||||
yield
|
||||
finally:
|
||||
set_caching_enabled(True)
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.sample_packing:
|
||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
|
||||
add_position_ids, num_proc=os.cpu_count()
|
||||
)
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
|
||||
add_position_ids, num_proc=os.cpu_count()
|
||||
)
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
if cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
if not cfg.total_num_tokens:
|
||||
LOG.info("calculating total_num_tokens")
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
total_num_steps = (
|
||||
# match count to len est in dataloader
|
||||
(
|
||||
math.floor(
|
||||
0.99
|
||||
* cfg.total_num_tokens
|
||||
/ cfg.sample_packing_eff_est
|
||||
/ cfg.sequence_len
|
||||
// cfg.batch_size
|
||||
// int(os.environ.get("WORLD_SIZE", 1))
|
||||
)
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
)
|
||||
LOG.info(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||
)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||
collate_fn=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
data_loader_len = data_loader.len_w_stats()
|
||||
actual_eff = data_loader.efficiency()
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.micro_batch_size
|
||||
* cfg.num_epochs
|
||||
// cfg.batch_size
|
||||
)
|
||||
)
|
||||
LOG.info(
|
||||
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
|
||||
)
|
||||
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
LOG.info(f"total_num_steps: {total_num_steps}")
|
||||
return total_num_steps
|
||||
|
||||
|
||||
def setup_fsdp_envs(cfg):
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
if cfg.fsdp_config.fsdp_sync_module_states:
|
||||
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
|
||||
if cfg.fsdp_config.fsdp_state_dict_type:
|
||||
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.fsdp:
|
||||
setup_fsdp_envs(cfg)
|
||||
warmup_steps = (
|
||||
cfg.warmup_steps
|
||||
if cfg.warmup_steps is not None
|
||||
@@ -50,19 +375,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
save_steps = cfg.save_steps
|
||||
eval_steps = cfg.eval_steps
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
if cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
|
||||
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
|
||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
if cfg.gradient_checkpointing is not None:
|
||||
|
||||
if cfg.seed:
|
||||
training_arguments_kwargs["seed"] = cfg.seed
|
||||
|
||||
if cfg.gradient_checkpointing:
|
||||
if cfg.gptq:
|
||||
from alpaca_lora_4bit.gradient_checkpointing import (
|
||||
apply_gradient_checkpointing,
|
||||
@@ -85,6 +412,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
|
||||
|
||||
if cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
|
||||
|
||||
# deepspeed
|
||||
if (
|
||||
os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
|
||||
@@ -97,7 +427,31 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
# TODO search Path("./") for one
|
||||
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
||||
|
||||
training_args = transformers.TrainingArguments(
|
||||
if cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
||||
if cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
||||
if cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
||||
if cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
||||
|
||||
if cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
|
||||
if cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = cfg.sample_packing_eff_est
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||
max_seq_length=cfg.sequence_len,
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
if cfg.eval_batch_size is not None
|
||||
@@ -107,18 +461,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
num_train_epochs=cfg.num_epochs,
|
||||
learning_rate=cfg.learning_rate,
|
||||
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
||||
save_strategy="steps" if save_steps else "epoch",
|
||||
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
|
||||
save_steps=save_steps,
|
||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||
save_steps=cfg.save_steps,
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=3,
|
||||
load_best_model_at_end=True
|
||||
if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
|
||||
and cfg.val_set_size > 0
|
||||
and save_steps is not None
|
||||
and save_steps % eval_steps == 0
|
||||
and cfg.load_in_8bit is not True
|
||||
else False,
|
||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||
load_best_model_at_end=(
|
||||
cfg.load_best_model_at_end is not False
|
||||
and cfg.val_set_size > 0
|
||||
and cfg.save_steps
|
||||
and cfg.save_steps % cfg.eval_steps == 0
|
||||
and cfg.load_in_8bit is not True
|
||||
)
|
||||
or False,
|
||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||
group_by_length=cfg.group_by_length,
|
||||
report_to="wandb" if cfg.use_wandb else None,
|
||||
@@ -128,6 +483,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine",
|
||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
@@ -140,7 +497,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if (
|
||||
cfg.optimizer == "adamw_bnb_8bit"
|
||||
and not cfg.gptq
|
||||
and not "deepspeed" in training_arguments_kwargs
|
||||
and "deepspeed" not in training_arguments_kwargs
|
||||
and not cfg.fsdp
|
||||
):
|
||||
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
||||
@@ -199,6 +556,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
|
||||
|
||||
callbacks = []
|
||||
callbacks.append(PrintGPUStatsCallback(cfg))
|
||||
|
||||
if cfg.relora_steps:
|
||||
relora_steps = int(cfg.relora_steps)
|
||||
relora_warmup_steps = int(cfg.relora_warmup_steps)
|
||||
callbacks.append(ReLoRACallback(cfg))
|
||||
|
||||
(optimizer, lr_scheduler) = trainer_kwargs["optimizers"]
|
||||
trainer_kwargs["optimizers"] = (
|
||||
optimizer,
|
||||
ReLoRAScheduler(optimizer, lr_scheduler, relora_steps, relora_warmup_steps),
|
||||
)
|
||||
|
||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
@@ -206,28 +576,54 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
|
||||
if cfg.local_rank == 0 and cfg.adapter in [
|
||||
"lora",
|
||||
"qlora",
|
||||
]: # only save in rank 0
|
||||
callbacks.append(SavePeftModelCallback)
|
||||
|
||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True,
|
||||
}
|
||||
if cfg.collator_pad_to_longest:
|
||||
data_collator_kwargs["padding"] = "longest"
|
||||
else:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
|
||||
LOG.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [train_dataset, eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = (
|
||||
OneCycleLRSchedulerTrainer
|
||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
||||
else transformers.Trainer
|
||||
else AxolotlTrainer
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
|
||||
@@ -1,7 +1,36 @@
|
||||
"""Module for validating config files"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if cfg.max_packed_sequence_len and cfg.sample_packing:
|
||||
raise ValueError(
|
||||
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
|
||||
)
|
||||
if cfg.max_packed_sequence_len:
|
||||
LOG.warning(
|
||||
str(
|
||||
PendingDeprecationWarning(
|
||||
"max_packed_sequence_len will be deprecated in favor of sample_packing"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.batch_size:
|
||||
LOG.warning(
|
||||
"%s\n%s",
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if cfg.load_4bit:
|
||||
raise ValueError(
|
||||
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
||||
@@ -30,17 +59,86 @@ def validate_config(cfg):
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.relora_steps and cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
logging.warning(
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
)
|
||||
|
||||
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
||||
raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
|
||||
raise ValueError(
|
||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||
)
|
||||
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if (
|
||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||
LOG.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".")[0]) < 2:
|
||||
LOG.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
LOG.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
|
||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
if cfg.push_to_hub_model_id:
|
||||
raise ValueError(
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention:
|
||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adamw w bf16
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Module for wandb utilities"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@@ -7,9 +9,13 @@ def setup_wandb_env_vars(cfg):
|
||||
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||
cfg.use_wandb = True
|
||||
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
||||
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||
else:
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
[
|
||||
{
|
||||
"instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
|
||||
"input": "Words: ['Hello', 'world'].",
|
||||
"output": "['world', 'Hello']"
|
||||
},
|
||||
{
|
||||
"instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
|
||||
"input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
|
||||
"output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
|
||||
}
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user