Compare commits
205 Commits
v0.2.0
...
flan-no-bo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e91fed495a | ||
|
|
756dfba97b | ||
|
|
91ab0592af | ||
|
|
0aeb7c7802 | ||
|
|
d35278aaf1 | ||
|
|
9492d4ebb7 | ||
|
|
ad5ca4f734 | ||
|
|
cb9d3af5c0 | ||
|
|
c969f0a9dc | ||
|
|
6d0ee4ba34 | ||
|
|
a81f52d575 | ||
|
|
1925eaf1e6 | ||
|
|
1ab3bf3e67 | ||
|
|
d7635b7148 | ||
|
|
88e17ffc50 | ||
|
|
baed440fa1 | ||
|
|
7925ddce86 | ||
|
|
6f849809c5 | ||
|
|
c16644d05e | ||
|
|
945c4191a3 | ||
|
|
136522f9c9 | ||
|
|
556fe408b3 | ||
|
|
16bb6276a5 | ||
|
|
06674a11f2 | ||
|
|
3513885f43 | ||
|
|
06652c1c39 | ||
|
|
068fc48978 | ||
|
|
aaadacf6b3 | ||
|
|
5ff547dc70 | ||
|
|
dc77c8ebce | ||
|
|
51a4c12242 | ||
|
|
4b43a66a0b | ||
|
|
34ae69989f | ||
|
|
fd2c9814c9 | ||
|
|
2ba4ae8f46 | ||
|
|
93dacba228 | ||
|
|
8002ffb41f | ||
|
|
74ef5cc083 | ||
|
|
5e616d91c0 | ||
|
|
94f310c7a6 | ||
|
|
8e568bbdae | ||
|
|
e21dab49fd | ||
|
|
52cde69288 | ||
|
|
9a58e99e81 | ||
|
|
c7dee56b87 | ||
|
|
aac4b7691e | ||
|
|
f31a338cbb | ||
|
|
4cd1deeef2 | ||
|
|
9ac16ed8d1 | ||
|
|
6b3f509d9e | ||
|
|
336aa3fd48 | ||
|
|
d0d7eaa4f3 | ||
|
|
a6ebf57e82 | ||
|
|
280832cec2 | ||
|
|
a43bae9ff0 | ||
|
|
effbbf6dd1 | ||
|
|
c9a149f9e8 | ||
|
|
c530e4b9c8 | ||
|
|
f620706776 | ||
|
|
77762a5d6b | ||
|
|
14668fa54e | ||
|
|
b565ecf0a1 | ||
|
|
fe0b76854e | ||
|
|
e944311442 | ||
|
|
e3e7b52a5b | ||
|
|
974dc00a7d | ||
|
|
572d1141e6 | ||
|
|
a6190c8094 | ||
|
|
563b6d89e6 | ||
|
|
cd0a6f6027 | ||
|
|
0e664a5ebc | ||
|
|
dd7d16d2eb | ||
|
|
e285e24f7f | ||
|
|
919727b4d7 | ||
|
|
5ffefee37f | ||
|
|
d9f713e4e3 | ||
|
|
958da70376 | ||
|
|
c4e4f8115c | ||
|
|
a808bf913f | ||
|
|
01248253a3 | ||
|
|
759e8673ce | ||
|
|
0c6f928601 | ||
|
|
eea2731a5e | ||
|
|
1db46a9c72 | ||
|
|
ab5cd28acf | ||
|
|
1a82082e91 | ||
|
|
1210dc8fd5 | ||
|
|
488a67d75a | ||
|
|
71a43f8479 | ||
|
|
39619028a3 | ||
|
|
8792199799 | ||
|
|
1edc30c786 | ||
|
|
14163c15d9 | ||
|
|
41e4f6ca31 | ||
|
|
79e2a6f140 | ||
|
|
c2508987a6 | ||
|
|
215d775147 | ||
|
|
f36e227eaf | ||
|
|
5878bb1f3a | ||
|
|
a03a7d7d8b | ||
|
|
fec6bcc3e6 | ||
|
|
931e606459 | ||
|
|
7f09106437 | ||
|
|
6b50200234 | ||
|
|
16f9e28048 | ||
|
|
b9083a7fc1 | ||
|
|
aefb2fc681 | ||
|
|
b5aa8d854c | ||
|
|
4d6490bce2 | ||
|
|
b242b69e10 | ||
|
|
320beb20f4 | ||
|
|
bd3b537344 | ||
|
|
813cfa4c14 | ||
|
|
2e13ceff37 | ||
|
|
2a801b001a | ||
|
|
e44c9e0b3e | ||
|
|
55b8542de8 | ||
|
|
febe902517 | ||
|
|
f4df266842 | ||
|
|
281dc3df59 | ||
|
|
2ef4634d45 | ||
|
|
7eae90333e | ||
|
|
c8242de725 | ||
|
|
2cfe9e9b16 | ||
|
|
79a8f52181 | ||
|
|
afaa0d2c01 | ||
|
|
bfd27ba55e | ||
|
|
babf0fdb71 | ||
|
|
a52f4816b0 | ||
|
|
81911d112c | ||
|
|
52765ac588 | ||
|
|
73e9ea4069 | ||
|
|
f8d379883d | ||
|
|
04a1b77307 | ||
|
|
2097a09d2d | ||
|
|
cfff94b123 | ||
|
|
2b222de5b6 | ||
|
|
df9528f865 | ||
|
|
193c73bce0 | ||
|
|
6abfd87d44 | ||
|
|
59bb2197ed | ||
|
|
9a02e7e1ff | ||
|
|
5b33e295bd | ||
|
|
4ac9e251b7 | ||
|
|
c9c050316f | ||
|
|
ca11ae9689 | ||
|
|
328c3bce96 | ||
|
|
5cd2126439 | ||
|
|
12620f3089 | ||
|
|
4ab0c8b201 | ||
|
|
74ebbf4371 | ||
|
|
76a70fd739 | ||
|
|
618816d4df | ||
|
|
91992cb8f5 | ||
|
|
84169d15b3 | ||
|
|
ecfe8d0a1a | ||
|
|
eee44a3b47 | ||
|
|
078a43eef8 | ||
|
|
33e1890086 | ||
|
|
1c38253692 | ||
|
|
496b83f778 | ||
|
|
ff68a95781 | ||
|
|
fb3d40f197 | ||
|
|
288fd62431 | ||
|
|
3c71c8debe | ||
|
|
a6f5e5eaec | ||
|
|
5a631b305b | ||
|
|
f94dd626f0 | ||
|
|
5079753b7a | ||
|
|
0136f510f2 | ||
|
|
72bf8aafb6 | ||
|
|
8afb0fbaba | ||
|
|
9b8585dc70 | ||
|
|
8eb5811d4e | ||
|
|
e0011fdf55 | ||
|
|
6e9e98720e | ||
|
|
c2a0792680 | ||
|
|
b267d24a2b | ||
|
|
5c3f5db38b | ||
|
|
e3d03745ba | ||
|
|
fac46002d4 | ||
|
|
33d40179ba | ||
|
|
dcb03d6da4 | ||
|
|
0e4be625ae | ||
|
|
bdc4bd7d4e | ||
|
|
2d0ba3b818 | ||
|
|
c7021e191f | ||
|
|
c56818b119 | ||
|
|
2675fb756e | ||
|
|
1076bcbbca | ||
|
|
2daa6835f0 | ||
|
|
e3c494ca7b | ||
|
|
ad0ea6aaab | ||
|
|
876edd83d0 | ||
|
|
6cb2310592 | ||
|
|
6fa40bf8ad | ||
|
|
3aad5f3b3e | ||
|
|
39a208c2bc | ||
|
|
2520ecd6df | ||
|
|
c5b0af1a7e | ||
|
|
988aeb9c34 | ||
|
|
cf61f14bff | ||
|
|
0abcd71a85 | ||
|
|
c43c5c84ff | ||
|
|
36ec6e1a0e |
13
.github/workflows/base.yml
vendored
13
.github/workflows/base.yml
vendored
@@ -12,17 +12,27 @@ jobs:
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: self-hosted
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: "117"
|
||||
cuda_version: 11.7.0
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
steps:
|
||||
@@ -46,12 +56,13 @@ jobs:
|
||||
context: .
|
||||
file: ./docker/Dockerfile-base
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTHON_VERSION=${{ matrix.python_version }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}
|
||||
|
||||
25
.github/workflows/main.yml
vendored
25
.github/workflows/main.yml
vendored
@@ -11,18 +11,27 @@ jobs:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
runs-on: self-hosted
|
||||
@@ -46,10 +55,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
@@ -62,14 +71,22 @@ jobs:
|
||||
include:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
runs-on: self-hosted
|
||||
@@ -93,10 +110,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
file: ./docker/Dockerfile-runpod
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -7,6 +7,7 @@ jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.9", "3.10"]
|
||||
timeout-minutes: 10
|
||||
|
||||
@@ -5,6 +5,9 @@ exclude = venv
|
||||
[mypy-alpaca_lora_4bit.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-flash_attn.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@@ -31,3 +34,6 @@ ignore_missing_imports = True
|
||||
|
||||
[mypy-addict]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-xformers.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
3
FAQS.md
3
FAQS.md
@@ -2,3 +2,6 @@
|
||||
|
||||
- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
|
||||
- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
|
||||
- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
|
||||
`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
|
||||
This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
|
||||
|
||||
225
README.md
225
README.md
@@ -16,31 +16,33 @@
|
||||
|
||||
## Axolotl supports
|
||||
|
||||
| | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
|
||||
|---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❓ |
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
|
||||
|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
|
||||
**Requirements**: Python 3.9.
|
||||
**Requirements**: Python 3.9 and Pytorch 2.0.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
accelerate config
|
||||
|
||||
# finetune lora
|
||||
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
--inference --lora_model_dir="./lora-out"
|
||||
```
|
||||
|
||||
@@ -50,18 +52,83 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
||||
|
||||
- Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
|
||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
|
||||
- `winglian/axolotl:dev`: dev branch (not usually up to date)
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
```sh
|
||||
docker compose up -d
|
||||
```
|
||||
- `winglian/axolotl:dev`: dev branch
|
||||
- `winglian/axolotl-runpod:main`: for runpod
|
||||
|
||||
- Conda/Pip venv
|
||||
1. Install python **3.9**
|
||||
|
||||
2. Install python dependencies with ONE of the following:
|
||||
- `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support)
|
||||
- `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq)
|
||||
- `pip3 install -e .[gptq_triton]`
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install python dependencies with ONE of the following:
|
||||
- Recommended, supports QLoRA, NO gptq/int4 support
|
||||
```bash
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
- gptq/int4 support, NO QLoRA
|
||||
```bash
|
||||
pip3 install -e .[gptq]
|
||||
```
|
||||
- same as above but not recommended
|
||||
```bash
|
||||
pip3 install -e .[gptq_triton]
|
||||
```
|
||||
|
||||
- LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
|
||||
1. Install python
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install -y python3.9
|
||||
|
||||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
|
||||
sudo update-alternatives --config python # pick 3.9 if given option
|
||||
python -V # should be 3.9
|
||||
|
||||
```
|
||||
|
||||
2. Install pip
|
||||
```bash
|
||||
wget https://bootstrap.pypa.io/get-pip.py
|
||||
python get-pip.py
|
||||
```
|
||||
|
||||
3. Install torch
|
||||
```bash
|
||||
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
4. Axolotl
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install -e . # change depend on needs
|
||||
pip3 install protobuf==3.20.3
|
||||
pip3 install -U requests
|
||||
pip3 install -U --ignore-installed psutil
|
||||
pip3 install -U scipy
|
||||
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
|
||||
```
|
||||
|
||||
5. Set path
|
||||
```bash
|
||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
|
||||
```
|
||||
</details>
|
||||
|
||||
### Dataset
|
||||
|
||||
@@ -71,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations
|
||||
- `sharegpt:chat`: conversations
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -112,13 +179,66 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"article": "...", "summary": "..."}
|
||||
```
|
||||
|
||||
> Have some new format to propose? Check if it's already defined in [data.py](src/axolotl/utils/data.py) in `dev` branch!
|
||||
- `alpaca_chat`: basic instruct for alpaca chat
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "response": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_qa`: question and answer for alpaca chat
|
||||
```json
|
||||
{"question": "...", "answer": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_concise`: question and answer for alpaca chat, for concise answers
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "response": "..."}
|
||||
```
|
||||
- `alpaca_chat.load_camel_ai`: question and answer for alpaca chat, for load_camel_ai
|
||||
```json
|
||||
{"message_1": "...", "message_2": "..."}
|
||||
```
|
||||
- `context_qa`: in context question answering from an article
|
||||
```json
|
||||
{"article": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||
```json
|
||||
{"article": "...", "unanswerable_question": "..."}
|
||||
```
|
||||
- `creative_acr.load_answer`: instruction and revision
|
||||
```json
|
||||
{"instruction": "...", "revision": "..."}
|
||||
```
|
||||
- `creative_acr.load_critique`: critique
|
||||
```json
|
||||
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "..."}
|
||||
```
|
||||
- `creative_acr.load_revise`: critique and revise
|
||||
```json
|
||||
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
|
||||
```
|
||||
- `pygmalion`: pygmalion
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt_jokes`: creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||
```json
|
||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type.
|
||||
|
||||
Optionally, download some datasets, see [data/README.md](data/README.md)
|
||||
|
||||
|
||||
|
||||
### Config
|
||||
|
||||
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
|
||||
@@ -144,6 +264,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
|
||||
bf16: true # require >=ampere
|
||||
fp16: true
|
||||
tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
|
||||
float16: true # use instead of fp16 when you don't want AMP
|
||||
```
|
||||
Note: Repo does not do 4-bit quantization.
|
||||
|
||||
@@ -171,6 +293,9 @@ base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
@@ -260,20 +385,22 @@ wandb_log_model: # 'checkpoint'
|
||||
output_dir: ./completed-model
|
||||
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
eval_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
logging_steps:
|
||||
save_steps:
|
||||
eval_steps:
|
||||
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
group_by_length: false
|
||||
|
||||
# does not work with current implementation of 4-bit LoRA
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
@@ -295,11 +422,27 @@ log_sweep_max_lr:
|
||||
optimizer:
|
||||
# specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# whether to bettertransformers
|
||||
flash_optimum:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
||||
flash_attention: # require a100 for llama
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# llama only
|
||||
xpos_rope:
|
||||
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -367,11 +510,16 @@ Pass the appropriate flag to the train command:
|
||||
|
||||
- Pretrained LORA:
|
||||
```bash
|
||||
--inference --lora_model_dir ./completed-model
|
||||
--inference --lora_model_dir="./lora-output-dir"
|
||||
```
|
||||
- Full weights finetune:
|
||||
```bash
|
||||
--inference --base_model ./completed-model
|
||||
--inference --base_model="./completed-model"
|
||||
```
|
||||
- Full weights finetune w/ a prompt from a text file:
|
||||
```bash
|
||||
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
|
||||
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
|
||||
```
|
||||
|
||||
### Merge LORA to base
|
||||
@@ -382,6 +530,12 @@ Add below flag to train command above
|
||||
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
```
|
||||
|
||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
|
||||
```
|
||||
|
||||
## Common Errors 🧰
|
||||
|
||||
> Cuda out of memory
|
||||
@@ -389,6 +543,7 @@ Add below flag to train command above
|
||||
Please reduce any below
|
||||
- `micro_batch_size`
|
||||
- `eval_batch_size`
|
||||
- `gradient_accumulation_steps`
|
||||
- `sequence_len`
|
||||
|
||||
> RuntimeError: expected scalar type Float but found Half
|
||||
@@ -399,10 +554,30 @@ Try set `fp16: true`
|
||||
|
||||
Try to turn off xformers.
|
||||
|
||||
## Need help? 🙋♂️
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
## Badge ❤🏷️
|
||||
|
||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||
|
||||
```markdown
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
```
|
||||
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
|
||||
## Community Showcase
|
||||
|
||||
Open Access AI Collective
|
||||
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
|
||||
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
|
||||
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
|
||||
|
||||
PocketDoc Labs
|
||||
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
|
||||
|
||||
## Contributing 🤝
|
||||
|
||||
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: 'NO'
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,40 +0,0 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- c_attn
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: pythia-1.4b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 32
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.0003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: True
|
||||
tf32: True
|
||||
gradient_checkpointing:
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,41 +0,0 @@
|
||||
base_model: facebook/galactica-1.3b
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 32
|
||||
micro_batch_size: 16
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
tf32: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
tokens:
|
||||
pad_token: "[PAD]"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,39 +0,0 @@
|
||||
base_model: huggyllama/llama-13b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
|
||||
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
|
||||
type: sharegpt
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.002
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-13b-sharegpt
|
||||
batch_size: 64
|
||||
micro_batch_size: 2
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
eval_steps:
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience: 5
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,44 +0,0 @@
|
||||
base_model: huggyllama/llama-65b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: anon8231489123/ShareGPT_Vicuna_unfiltered
|
||||
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-65b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
micro_batch_size: 16
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca # original alpaca dataset
|
||||
type: alpaca
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-test
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
load_4bit: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
@@ -1,41 +0,0 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
micro_batch_size: 16
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca # original alpaca dataset
|
||||
type: alpaca
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
max_packed_sequence_len: 1024
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-test
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
gptq: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
@@ -1,86 +0,0 @@
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
base_model: decapoda-research/llama-7b-hf-int4
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: decapoda-research/llama-7b-hf
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
tokenizer_type: AutoTokenizer
|
||||
# whether you are training a 4-bit quantized model
|
||||
load_4bit: true
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# this can be either a hf dataset, or relative path
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
|
||||
val_set_size: 0.04
|
||||
# if you want to use lora, leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
lora_model_dir:
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
max_packed_sequence_len: 1024
|
||||
# lora hyperparameters
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
# wandb configuration if your're using it
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
# where to save the finsihed model to
|
||||
output_dir: ./completed-model
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# don't use this, leads to wonky training (according to someone on the internet)
|
||||
group_by_length: false
|
||||
# Use CUDA bf16
|
||||
bf16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true
|
||||
# does not work with current implementation of 4-bit LoRA
|
||||
gradient_checkpointing: false
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
# specify a scheduler to use with the optimizer. only one_cycle is supported currently
|
||||
lr_scheduler:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
||||
flash_attention:
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
auto_resume_from_checkpoints: false
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
@@ -1,56 +0,0 @@
|
||||
base_model: stabilityai/stablelm-base-alpha-3b
|
||||
base_model_config: stabilityai/stablelm-base-alpha-3b
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
max_packed_sequence_len: 4096
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: stable-alpaca-3b
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./stable-alpaca-3b
|
||||
batch_size: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 100
|
||||
eval_steps: 50
|
||||
save_steps: 200
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.01
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
#tokens:
|
||||
# pad_token: "[PAD]"
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
|
||||
base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_4bit: true
|
||||
gptq_groupsize: 128
|
||||
gptq_model_v1: false
|
||||
datasets:
|
||||
# https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
|
||||
- path: data/alpaca_reflect_pruned.jsonl
|
||||
type: reflection
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
val_set_size: 0.04
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-reflect
|
||||
batch_size: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
flash_attention: true
|
||||
@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
|
||||
## Convert the JSON data files to JSONL.
|
||||
|
||||
```shell
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
```
|
||||
---
|
||||
|
||||
|
||||
20
docker-compose.yaml
Normal file
20
docker-compose.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
# version: '3.8'
|
||||
services:
|
||||
axolotl:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: ./docker/Dockerfile
|
||||
volumes:
|
||||
- .:/workspace/axolotl
|
||||
- ~/.cache/huggingface/:/root/.cache/huggingface/
|
||||
# set environment variables
|
||||
environment:
|
||||
- WANDB_API_KEY=${WANDB_API_KEY}
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
# count: 1
|
||||
capabilities: [gpu]
|
||||
command: tail -f /dev/null
|
||||
@@ -13,8 +13,7 @@ RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/pe
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main"
|
||||
|
||||
RUN mkdir axolotl
|
||||
COPY . axolotl/
|
||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN cd axolotl && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
|
||||
@@ -52,6 +52,8 @@ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
|
||||
60
examples/cerebras/qlora.yml
Normal file
60
examples/cerebras/qlora.yml
Normal file
@@ -0,0 +1,60 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- c_fc
|
||||
- c_attn
|
||||
- c_proj
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -23,7 +23,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: falcon-7b
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
92
examples/falcon/config-7b-qlora.yml
Normal file
92
examples/falcon/config-7b-qlora.yml
Normal file
@@ -0,0 +1,92 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
gptq: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: QingyiSi/Alpaca-CoT
|
||||
data_files:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
|
||||
# hyperparameters from QLoRA paper Appendix B.2
|
||||
# "We find hyperparameters to be largely robust across datasets"
|
||||
lora_r: 64
|
||||
lora_alpha: 16
|
||||
# 0.1 for models up to 13B
|
||||
# 0.05 for 33B and 65B models
|
||||
lora_dropout: 0.05
|
||||
# add LoRA modules on all linear layers of the base model
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
# QLoRA paper Table 9
|
||||
# - 16 for 7b & 13b
|
||||
# - 32 for 33b, 64 for 64b
|
||||
# Max size tested on A6000
|
||||
# - 7b: 40
|
||||
# - 40b: 4
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
# QLoRA paper Table 9
|
||||
# - 2e-4 for 7b & 13b
|
||||
# - 1e-4 for 33b & 64b
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 5
|
||||
save_steps: 10
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.000001
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
@@ -23,7 +23,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project: falcon-7b
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
57
examples/gptj/qlora.yml
Normal file
57
examples/gptj/qlora.yml
Normal file
@@ -0,0 +1,57 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
base_model_config: EleutherAI/gpt-j-6b
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -3,6 +3,6 @@
|
||||
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
|
||||
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
|
||||
|
||||
```
|
||||
|
||||
@@ -26,7 +26,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-7b-lora-int4
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
@@ -7,30 +7,28 @@ datasets:
|
||||
- path: openaccess-ai-collective/jeopardy
|
||||
type: jeopardy
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
sequence_len: 512
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: jeopardy-bot-7b
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
batch_size: 4
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
learning_rate: 0.00003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
@@ -48,11 +46,10 @@ eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
tokens:
|
||||
pad_token: "[PAD]"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -24,7 +24,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
16
examples/openllama-3b/README.md
Normal file
16
examples/openllama-3b/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# openllama-3b
|
||||
|
||||
Basic full tune
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
|
||||
```
|
||||
|
||||
LoRA
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
QLoRA
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
|
||||
```
|
||||
62
examples/openllama-3b/config.yml
Normal file
62
examples/openllama-3b/config.yml
Normal file
@@ -0,0 +1,62 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 256
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
float16: true
|
||||
bf16: false
|
||||
fp16: false
|
||||
tf32: false
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model_config: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -49,7 +49,7 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model_config: openlm-research/open_llama_3b_600bt_preview
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
9
examples/pythia-12b/README.md
Normal file
9
examples/pythia-12b/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Pythia 12B
|
||||
|
||||
- Single-GPU A100 only (?)
|
||||
|
||||
```shell
|
||||
python scripts/finetune.py examples/pythia-12b/config.yml
|
||||
```
|
||||
|
||||
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️
|
||||
@@ -1,39 +1,49 @@
|
||||
base_model: EleutherAI/gpt-neox-20b
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_config: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
load_in_8bit: true
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
device_map: auto
|
||||
datasets:
|
||||
- path: nomic-ai/gpt4all-j-prompt-generations
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
shards: 4
|
||||
shards_index: 0
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_r: 64
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_dropout: 0.0
|
||||
lora_target_modules:
|
||||
- query_key_value
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project: gpt4all-neox-20b
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./gpt4all-neox-20b
|
||||
batch_size: 48
|
||||
micro_batch_size: 4
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
lr_scheduler: one_cycle
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: True
|
||||
tf32: True
|
||||
bf16: false
|
||||
fp16: false
|
||||
float16: true
|
||||
tf32: true
|
||||
flash_optimum: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
gradient_checkpointing: true
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
collator_pad_to_longest: true
|
||||
@@ -1,36 +1,29 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: data/alpaca_data_gpt4.jsonl
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
- path: data/vicuna_cleaned.jsonl
|
||||
type: sharegpt
|
||||
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
type: gpteacher
|
||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
type: gpteacher
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
lora_r: 8
|
||||
sequence_len: 512
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- query_key_value
|
||||
# - xxx
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project: pythia-1.4b-lora
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 48
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
@@ -39,3 +32,6 @@ tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 20
|
||||
logging_steps: 1
|
||||
@@ -1,6 +0,0 @@
|
||||
# qlora-openllama-3b
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/qlora-openllama-3b/config.yml
|
||||
|
||||
```
|
||||
@@ -1,7 +1,7 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: GPTNeoXTokenizer
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code:
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
|
||||
BIN
image/axolotl-badge-web.png
Normal file
BIN
image/axolotl-badge-web.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
@@ -1,6 +1,7 @@
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.39.0
|
||||
accelerate
|
||||
addict
|
||||
fire
|
||||
PyYAML==6.0
|
||||
@@ -10,6 +11,7 @@ sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers
|
||||
optimum
|
||||
# qlora things
|
||||
bert-score==0.3.13
|
||||
evaluate==0.4.0
|
||||
|
||||
@@ -13,11 +13,13 @@ import fire
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from axolotl.utils.data import load_prepare_datasets
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.validation import validate_config
|
||||
@@ -46,10 +48,11 @@ def choose_device(cfg):
|
||||
return "cpu"
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.device == "cuda":
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
if cfg.device_map != "auto":
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
@@ -61,38 +64,68 @@ def get_multi_line_input() -> Optional[str]:
|
||||
return instruction
|
||||
|
||||
|
||||
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
||||
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
||||
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
||||
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# gc = GenerationConfig() # TODO swap out and use this
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=100,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextStreamer(tokenizer)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
@@ -142,24 +175,30 @@ def train(
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or cfg.strict is False:
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.gradient_accumulation_steps = (
|
||||
cfg.gradient_accumulation_steps // cfg.world_size
|
||||
)
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
@@ -168,18 +207,31 @@ def train(
|
||||
cfg.fp16 = True
|
||||
cfg.bf16 = False
|
||||
|
||||
validate_config(cfg)
|
||||
if cfg.tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# load the tokenizer first
|
||||
logging.info("loading tokenizer...")
|
||||
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
logging.info(f"loading tokenizer... {tokenizer_config}")
|
||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||
|
||||
if check_not_in(
|
||||
["inference", "shard", "merge_lora"], kwargs
|
||||
if (
|
||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||
): # don't need to load dataset for these
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
if not cfg.pretraining_dataset:
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
train_dataset = load_pretraining_dataset(
|
||||
cfg.pretraining_dataset,
|
||||
tokenizer,
|
||||
max_tokens=cfg.sequence_len,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
logging.info("check_dataset_labels...")
|
||||
@@ -203,7 +255,6 @@ def train(
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter=cfg.adapter,
|
||||
inference=("inference" in kwargs),
|
||||
)
|
||||
|
||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||
@@ -216,9 +267,15 @@ def train(
|
||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
return
|
||||
|
||||
if "inference" in kwargs:
|
||||
if cfg.inference:
|
||||
logging.info("calling do_inference function")
|
||||
do_inference(cfg, model, tokenizer)
|
||||
prompter: Optional[str] = "AlpacaPrompter"
|
||||
if "prompter" in kwargs:
|
||||
if kwargs["prompter"] == "None":
|
||||
prompter = None
|
||||
else:
|
||||
prompter = kwargs["prompter"]
|
||||
do_inference(cfg, model, tokenizer, prompter=prompter)
|
||||
return
|
||||
|
||||
if "shard" in kwargs:
|
||||
@@ -240,12 +297,15 @@ def train(
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(
|
||||
signal.SIGINT,
|
||||
lambda signal, frame: (
|
||||
model.save_pretrained(cfg.output_dir),
|
||||
sys.exit(0),
|
||||
),
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
logging.info("Starting trainer...")
|
||||
@@ -265,13 +325,24 @@ def train(
|
||||
logging.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
||||
)
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
@@ -33,12 +33,16 @@ class TokenizedPromptDataset(IterableDataset):
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
count = 0
|
||||
# Loop through the entire dataset
|
||||
for example in iterator:
|
||||
try:
|
||||
yield self.prompt_tokenizer.tokenize_prompt(example)
|
||||
count += 1
|
||||
except InvalidDataException:
|
||||
pass
|
||||
if count == 0:
|
||||
raise RuntimeError("Expected at least one datapoint in dataset.")
|
||||
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
@@ -127,6 +131,11 @@ class ConstantLengthDataset(IterableDataset):
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
if (
|
||||
buffer["input_ids"]
|
||||
and input_ids[0] == self.tokenizer.bos_token_id
|
||||
):
|
||||
attention_mask[0] = 0
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
|
||||
@@ -25,6 +25,7 @@ def forward(
|
||||
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
|
||||
233
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
233
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers.models.llama.modeling_llama
|
||||
from torch import nn
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
logging.error("xformers not found! Please install it before trying to use it.")
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states, key_states, value_states, attn_bias=None
|
||||
)
|
||||
else:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
1249
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
94
src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
"""
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.models.llama.modeling_llama
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class XposRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scale_base=2048,
|
||||
use_xpos=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.scale_base = scale_base
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
||||
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.register_buffer("freqs_cached", freqs, persistent=False)
|
||||
|
||||
if not use_xpos:
|
||||
self.register_buffer("scale", None)
|
||||
self.register_buffer("scale_cached", torch.ones(1))
|
||||
return
|
||||
|
||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
||||
scale_cached = scale ** rearrange(power, "n -> n 1")
|
||||
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
||||
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
seq_len,
|
||||
):
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
||||
self.inv_freq
|
||||
)
|
||||
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
||||
|
||||
self.register_buffer("freqs_cached", freqs)
|
||||
|
||||
if self.scale is None:
|
||||
self.register_buffer(
|
||||
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
||||
)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
||||
|
||||
power = (t - (seq_len // 2)) / self.scale_base
|
||||
scale = self.scale ** rearrange(power, "n -> n 1")
|
||||
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
||||
self.register_buffer("scale_cached", scale)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
||||
freqs = freqs[position_ids, :]
|
||||
if scale.shape[-1] != 1:
|
||||
scale = scale[position_ids, :]
|
||||
|
||||
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
||||
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def replace_llama_rope_with_xpos_rope():
|
||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
@@ -18,6 +18,40 @@ def load(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
class AlpacaConcisePrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
|
||||
class AlpacaChatPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
self.prompt_style = PromptStyle.CHAT.value
|
||||
self.match_prompt_style()
|
||||
|
||||
|
||||
class NoSystemPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Null Prompter with no system prompts
|
||||
"""
|
||||
|
||||
prompt_input = "{instruction} {input} "
|
||||
prompt_no_input = "{instruction} "
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
pass
|
||||
|
||||
|
||||
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for AlpacaQA
|
||||
@@ -31,9 +65,40 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for CamelAI datasets
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["message_1"],
|
||||
"",
|
||||
prompt["message_2"],
|
||||
)
|
||||
|
||||
|
||||
def load_concise(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaConcisePrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_camel_ai(tokenizer, cfg):
|
||||
return CamelAIPromptTokenizingStrategy(
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
67
src/axolotl/prompt_strategies/context_qa.py
Normal file
67
src/axolotl/prompt_strategies/context_qa.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
# article, unanswerable_question, question, answer
|
||||
def load_404(tokenizer, cfg):
|
||||
return AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
AlpacaContextPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaContextPromptTokenizingStrategy(
|
||||
AlpacaContextPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Customized system prompted for concise QA
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"Use the following contextual information to concisely answer the question.\n"
|
||||
)
|
||||
system_no_input_prompt = (
|
||||
"Use the following contextual information to concisely answer the question.\n"
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question and answer
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"] + "\n===\n" + prompt["question"],
|
||||
"",
|
||||
prompt["answer"],
|
||||
)
|
||||
|
||||
|
||||
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
InstructionPromptTokenizingStrategy
|
||||
):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question that can't be answered
|
||||
from the context and a default response to that effect
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
|
||||
"",
|
||||
"The context provided does not contain any information about your inquiry. "
|
||||
"Therefore, I'm unable to answer your question based on the given context.",
|
||||
)
|
||||
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization strategy for asking bot to tell a joke and then explain why its funny
|
||||
"""
|
||||
|
||||
# title, text, explanation
|
||||
def get_conversation_thread(self, prompt):
|
||||
title = "" if not prompt["title"] else prompt["title"] + " "
|
||||
return [
|
||||
{"from": "human", "value": "Tell me a joke."},
|
||||
{"from": "gpt", "value": title + prompt["text"]},
|
||||
{"from": "human", "value": "Why is that joke funny?"},
|
||||
{"from": "gpt", "value": prompt["explanation"]},
|
||||
]
|
||||
67
src/axolotl/prompt_strategies/sharegpt_simple.py
Normal file
67
src/axolotl/prompt_strategies/sharegpt_simple.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_guanaco(tokenizer, cfg):
|
||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
return turns
|
||||
|
||||
|
||||
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
@@ -73,8 +73,17 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
elif ( # some tokenizers automatically add an eos token, let's remove it
|
||||
not add_eos_token and result["input_ids"][-1] == self.tokenizer.eos_token_id
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][:-1]
|
||||
result["attention_mask"] = result["attention_mask"][:-1]
|
||||
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
if (
|
||||
self.tokenizer.bos_token_id
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
@@ -96,25 +105,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_full_prompt
|
||||
return tokenized_prompt
|
||||
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response # pylint: disable=redefined-builtin
|
||||
@@ -410,7 +421,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
if (
|
||||
self.tokenizer.bos_token_id
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
|
||||
@@ -261,28 +261,33 @@ class Conversation:
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
conv_vicuna_v1_1 = Conversation(
|
||||
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_style=None):
|
||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
||||
if prompt_style != PromptStyle.CHAT.value:
|
||||
raise ValueError(
|
||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
||||
)
|
||||
system: str = (
|
||||
system_prompt
|
||||
if system_prompt
|
||||
else (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
)
|
||||
)
|
||||
self._conversation = Conversation(
|
||||
system=system,
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
|
||||
# def match_prompt_style(self):
|
||||
# if self.prompt_style == PromptStyle.chat.value:
|
||||
@@ -300,7 +305,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError
|
||||
|
||||
conv = conv_vicuna_v1_1.copy()
|
||||
conv = self._conversation.copy()
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
try:
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
import os
|
||||
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
||||
kwargs["model"].save_pretrained(peft_model_path)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
class SaveBetterTransformerModelCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the BetterTransformer wrapped model"""
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Save
|
||||
if (
|
||||
args.save_strategy == IntervalStrategy.STEPS
|
||||
and args.save_steps > 0
|
||||
and state.global_step % args.save_steps == 0
|
||||
):
|
||||
control.should_save = True
|
||||
|
||||
if control.should_save:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
model = BetterTransformer.reverse(kwargs["model"])
|
||||
model.save_pretrained(checkpoint_folder)
|
||||
# FIXME - need to cleanup old checkpoints
|
||||
|
||||
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
|
||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||
control.should_save = False
|
||||
return control
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Module containing data utilities"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -78,6 +79,13 @@ def load_tokenized_prepared_datasets(
|
||||
else:
|
||||
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||
logging.info("Loading raw datasets...")
|
||||
|
||||
if cfg.seed:
|
||||
seed = cfg.seed
|
||||
else:
|
||||
logging.info("No seed provided, using default seed of 42")
|
||||
seed = 42
|
||||
|
||||
datasets = []
|
||||
# pylint: disable=invalid-name
|
||||
for d in cfg.datasets:
|
||||
@@ -127,11 +135,11 @@ def load_tokenized_prepared_datasets(
|
||||
# support for using a subset of the data
|
||||
if d.shards:
|
||||
if "train" in ds:
|
||||
ds = ds.shuffle(seed=42)["train"].shard(
|
||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||
num_shards=d.shards, index=0
|
||||
)
|
||||
else:
|
||||
ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
|
||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||
d_type = d.type
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
@@ -232,13 +240,21 @@ def load_tokenized_prepared_datasets(
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||
suffix = ""
|
||||
if ":load_" in d.type:
|
||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||
logging.error(
|
||||
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
)
|
||||
logging.info("tokenizing, merging, and shuffling master dataset")
|
||||
|
||||
samples: List[int] = []
|
||||
for d in datasets:
|
||||
samples = samples + list(d)
|
||||
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
||||
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||
if cfg.local_rank == 0:
|
||||
logging.info(
|
||||
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
||||
@@ -379,8 +395,127 @@ def load_prepare_datasets(
|
||||
index=cfg.dataset_shard_idx,
|
||||
)
|
||||
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
if cfg.val_set_size:
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
else:
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def encode_pretraining(tokenizer, max_tokens, examples):
|
||||
res = tokenizer(
|
||||
examples["text"],
|
||||
truncation=True,
|
||||
max_length=max_tokens - 2,
|
||||
add_special_tokens=True,
|
||||
)
|
||||
# Convert to PyTorch tensors
|
||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||
new_input_ids = []
|
||||
new_attention_mask = []
|
||||
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
||||
for i, _ in enumerate(input_ids):
|
||||
input_ids[i] = torch.cat(
|
||||
(
|
||||
input_ids[i],
|
||||
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
||||
|
||||
# Concatenate tokens so that their lengths are less than max_tokens
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
for ids, mask in zip(input_ids, attention_mask):
|
||||
if buffer_input_ids.numel() == max_tokens:
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
else:
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
|
||||
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
||||
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
|
||||
ret = {
|
||||
"input_ids": [seq.tolist() for seq in new_input_ids],
|
||||
"labels": [seq.tolist() for seq in new_input_ids],
|
||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
||||
}
|
||||
|
||||
logging.debug(len(ret["input_ids"]))
|
||||
return ret
|
||||
|
||||
|
||||
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
dataset = load_dataset(path, streaming=True, split="train")
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||
# TODO dynamically figure out which columns/features to remove
|
||||
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
|
||||
return dataset
|
||||
|
||||
@@ -10,39 +10,38 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM # noqa: F401
|
||||
from transformers import PreTrainedModel # noqa: F401
|
||||
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
try:
|
||||
from transformers import LlamaForCausalLM
|
||||
except ImportError:
|
||||
logging.warning(
|
||||
"This version of transformers does not support Llama. Consider upgrading."
|
||||
)
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from peft import PeftConfig # noqa: F401
|
||||
from transformers import PreTrainedTokenizer # noqa: F401
|
||||
|
||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
|
||||
@@ -71,42 +70,62 @@ def load_tokenizer(
|
||||
|
||||
|
||||
def load_model(
|
||||
base_model,
|
||||
base_model_config,
|
||||
model_type,
|
||||
tokenizer,
|
||||
cfg,
|
||||
adapter="lora",
|
||||
inference=False,
|
||||
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||
):
|
||||
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
"""
|
||||
Load a model from a base model and a model type.
|
||||
"""
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
is_llama_derived_model = "llama" in base_model or (
|
||||
cfg.is_llama_derived_model = "llama" in base_model or (
|
||||
cfg.model_type and "llama" in cfg.model_type.lower()
|
||||
)
|
||||
|
||||
if is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and inference is False:
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||
|
||||
logging.info("patching with flash attention")
|
||||
replace_llama_attn_with_flash_attn()
|
||||
elif is_llama_derived_model and cfg.xformers_attention:
|
||||
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
|
||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
)
|
||||
|
||||
logging.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_sdp_attention,
|
||||
)
|
||||
|
||||
if cfg.bf16:
|
||||
logging.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
MEM_TOKEN,
|
||||
patch_llama_with_landmark_attn,
|
||||
)
|
||||
|
||||
logging.info("patching with landmark attention")
|
||||
patch_llama_with_landmark_attn()
|
||||
|
||||
# Note: This might overwrite previous additional_special_tokens
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
)
|
||||
|
||||
logging.info("patching with xpos rope")
|
||||
replace_llama_rope_with_xpos_rope()
|
||||
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16:
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
@@ -117,11 +136,18 @@ def load_model(
|
||||
)
|
||||
|
||||
replace_peft_model_with_int4_lora_model()
|
||||
from peft import prepare_model_for_int8_training
|
||||
except Exception as err:
|
||||
logging.exception(err)
|
||||
raise err
|
||||
|
||||
try:
|
||||
from peft import prepare_model_for_kbit_training
|
||||
except ImportError:
|
||||
# For backward compatibility
|
||||
from peft import (
|
||||
prepare_model_for_int8_training as prepare_model_for_kbit_training,
|
||||
)
|
||||
|
||||
model_kwargs = {}
|
||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
@@ -133,7 +159,7 @@ def load_model(
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
try:
|
||||
if cfg.gptq and is_llama_derived_model:
|
||||
if cfg.gptq and cfg.is_llama_derived_model:
|
||||
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@@ -171,9 +197,13 @@ def load_model(
|
||||
else True,
|
||||
)
|
||||
load_in_8bit = False
|
||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||
elif cfg.is_llama_derived_model:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config = LlamaConfig.from_pretrained(base_model_config)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
@@ -221,6 +251,22 @@ def load_model(
|
||||
base_model,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(config, "max_seq_len")
|
||||
and config.max_seq_len
|
||||
and cfg.sequence_len > config.max_seq_len
|
||||
):
|
||||
config.max_seq_len = cfg.sequence_len
|
||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(config, "max_sequence_length")
|
||||
and config.max_sequence_length
|
||||
and cfg.sequence_len > config.max_sequence_length
|
||||
):
|
||||
config.max_sequence_length = cfg.sequence_len
|
||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
@@ -239,6 +285,7 @@ def load_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
@@ -248,12 +295,24 @@ def load_model(
|
||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
|
||||
if (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings
|
||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||
):
|
||||
logging.warning(
|
||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
):
|
||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||
model = prepare_model_for_int8_training(model)
|
||||
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
|
||||
model, lora_config = load_adapter(model, cfg, adapter)
|
||||
|
||||
@@ -291,6 +350,9 @@ def load_model(
|
||||
logging.warning("there are no parameters that require gradient updates")
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, lora_config
|
||||
|
||||
@@ -323,7 +385,6 @@ def load_llama_adapter(model, cfg):
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
device_map=cfg.device_map,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
else:
|
||||
@@ -385,8 +446,7 @@ def load_lora(model, cfg):
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.lora_model_dir,
|
||||
device_map=cfg.device_map,
|
||||
# torch_dtype=torch.float16,
|
||||
is_trainable=not cfg.inference,
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
@@ -15,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers import EarlyStoppingCallback, Trainer
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.callbacks import SavePeftModelCallback
|
||||
from axolotl.utils.callbacks import (
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
)
|
||||
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
||||
|
||||
|
||||
@@ -62,8 +66,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
save_steps = cfg.save_steps
|
||||
eval_steps = cfg.eval_steps
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
if cfg.bf16 == "full":
|
||||
@@ -74,6 +76,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if cfg.seed:
|
||||
training_arguments_kwargs["seed"] = cfg.seed
|
||||
|
||||
if cfg.gradient_checkpointing:
|
||||
if cfg.gptq:
|
||||
from alpaca_lora_4bit.gradient_checkpointing import (
|
||||
@@ -109,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
# TODO search Path("./") for one
|
||||
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
||||
|
||||
if cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
||||
if cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
||||
if cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
||||
if cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
||||
|
||||
training_args = transformers.TrainingArguments(
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
@@ -119,16 +134,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
num_train_epochs=cfg.num_epochs,
|
||||
learning_rate=cfg.learning_rate,
|
||||
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
||||
save_strategy="steps" if save_steps else "epoch",
|
||||
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
|
||||
save_steps=save_steps,
|
||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||
save_steps=cfg.save_steps,
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=3,
|
||||
load_best_model_at_end=(
|
||||
cfg.load_best_model_at_end is not False
|
||||
and cfg.val_set_size > 0
|
||||
and save_steps
|
||||
and save_steps % eval_steps == 0
|
||||
and cfg.save_steps
|
||||
and cfg.save_steps % cfg.eval_steps == 0
|
||||
and cfg.load_in_8bit is not True
|
||||
)
|
||||
or False,
|
||||
@@ -225,6 +240,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
]: # only save in rank 0
|
||||
callbacks.append(SavePeftModelCallback)
|
||||
|
||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True,
|
||||
}
|
||||
@@ -233,6 +251,26 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
else:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from functools import partial
|
||||
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
|
||||
logging.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [train_dataset, eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = (
|
||||
OneCycleLRSchedulerTrainer
|
||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
||||
|
||||
@@ -2,8 +2,20 @@
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.batch_size:
|
||||
logging.warning(
|
||||
"%s\n%s",
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if cfg.load_4bit:
|
||||
raise ValueError(
|
||||
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
||||
@@ -44,7 +56,50 @@ def validate_config(cfg):
|
||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||
)
|
||||
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if (
|
||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
logging.warning(
|
||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||
)
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||
logging.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".")[0]) < 2:
|
||||
logging.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
logging.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
|
||||
if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adamw w bf16
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
@@ -15,3 +15,5 @@ def setup_wandb_env_vars(cfg):
|
||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||
else:
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
[
|
||||
{
|
||||
"instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
|
||||
"input": "Words: ['Hello', 'world'].",
|
||||
"output": "['world', 'Hello']"
|
||||
},
|
||||
{
|
||||
"instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
|
||||
"input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
|
||||
"output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
|
||||
}
|
||||
]
|
||||
65
tests/test_packed_dataset.py
Normal file
65
tests/test_packed_dataset.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Module for testing dataset sequence packing"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
|
||||
|
||||
class TestPacking(unittest.TestCase):
|
||||
"""
|
||||
Test class for packing dataset sequences
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_resets_attention(self):
|
||||
prompter = AlpacaPrompter("chat")
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
dateset = load_dataset(
|
||||
"json",
|
||||
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
|
||||
)["train"]
|
||||
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
|
||||
|
||||
constant_len_dataset = ConstantLengthDataset(
|
||||
self.tokenizer,
|
||||
[dataset],
|
||||
seq_length=2048,
|
||||
)
|
||||
packed_dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
example = packed_dataset[0]
|
||||
next_bos_index = (
|
||||
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
|
||||
) # add one since we sliced
|
||||
|
||||
# first example doesn't have mask reset
|
||||
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][0] == 1
|
||||
|
||||
# but subsequent one does
|
||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][next_bos_index] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -6,8 +6,12 @@ from pathlib import Path
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompter
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
|
||||
|
||||
logging.basicConfig(level="INFO")
|
||||
|
||||
@@ -18,6 +22,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
@@ -28,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_sharegpt_integration(self):
|
||||
print(Path(__file__).parent)
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
@@ -52,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
prompter = NoSystemPrompter()
|
||||
# pylint: disable=duplicate-code
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
sample = {
|
||||
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
|
||||
"output": "world!",
|
||||
}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
world_idx = example["input_ids"].index(3186)
|
||||
assert example["labels"][world_idx] == 3186
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
def test_alpaca(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
prompter = AlpacaPrompter()
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
world_idx = example["input_ids"].index(6324)
|
||||
assert example["labels"][world_idx] == 6324
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,6 +15,12 @@ class ValidationTest(unittest.TestCase):
|
||||
Test the validation module
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
def test_load_4bit_deprecate(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -23,6 +31,17 @@ class ValidationTest(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_batch_size_unused_warning(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"batch_size": 32,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert "batch_size is not recommended" in self._caplog.records[0].message
|
||||
|
||||
def test_qlora(self):
|
||||
base_cfg = DictDefault(
|
||||
{
|
||||
@@ -117,3 +136,180 @@ class ValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
validate_config(cfg)
|
||||
|
||||
def test_gradient_accumulations_or_batch_size(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_falcon_fsdp(self):
|
||||
regex_exp = r".*FSDP is not supported for falcon models.*"
|
||||
|
||||
# Check for lower-case
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "tiiuae/falcon-7b",
|
||||
"fsdp": ["full_shard", "auto_wrap"],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
# Check for upper-case
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Falcon-7b",
|
||||
"fsdp": ["full_shard", "auto_wrap"],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "tiiuae/falcon-7b",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_mpt_gradient_checkpointing(self):
|
||||
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
|
||||
|
||||
# Check for lower-case
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "mosaicml/mpt-7b",
|
||||
"gradient_checkpointing": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_flash_optimum(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
"adapter": "lora",
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"probably set bfloat16 or float16" in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
regex_exp = r".*AMP is not supported.*"
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"flash_optimum": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
regex_exp = r".*AMP is not supported.*"
|
||||
|
||||
with pytest.raises(ValueError, match=regex_exp):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_adamw_hyperparams(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": None,
|
||||
"adamw_epsilon": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": "adafactor",
|
||||
"adamw_beta1": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"adamw_beta1": 0.0001,
|
||||
"adamw_beta2": 0.0001,
|
||||
"adamw_epsilon": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"optimizer": "adafactor",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user