Compare commits
9 Commits
ia3-peft
...
attn-patch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
587dbbfc02 | ||
|
|
6c306d9186 | ||
|
|
7565fb9d63 | ||
|
|
a6b737d5ff | ||
|
|
cf95b57c0a | ||
|
|
13f7efaf74 | ||
|
|
d773384f74 | ||
|
|
985dcbc051 | ||
|
|
5d0b27e5a1 |
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
@@ -53,13 +53,6 @@ body:
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: config
|
||||
attributes:
|
||||
label: Config yaml
|
||||
description: |
|
||||
Please attach the config yaml!
|
||||
|
||||
- type: textarea
|
||||
id: possible-solution
|
||||
attributes:
|
||||
|
||||
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -25,11 +25,6 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
28
.github/workflows/main.yml
vendored
28
.github/workflows/main.yml
vendored
@@ -13,22 +13,22 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 118
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: 118
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: 118
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras: gptq
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -49,12 +49,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
@@ -76,10 +74,10 @@ jobs:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras: gptq
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
16
.github/workflows/pre-commit.yml
vendored
Normal file
16
.github/workflows/pre-commit.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
45
.github/workflows/pypi.yml
vendored
45
.github/workflows/pypi.yml
vendored
@@ -1,45 +0,0 @@
|
||||
name: publish pypi
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/axolotl
|
||||
permissions:
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel
|
||||
pip3 install -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||
|
||||
- name: Update version in setup.py
|
||||
run: >-
|
||||
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||
|
||||
- name: Build a binary wheel
|
||||
run: >-
|
||||
python setup.py sdist bdist_wheel
|
||||
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
58
.github/workflows/tests.yml
vendored
58
.github/workflows/tests.yml
vendored
@@ -1,32 +1,10 @@
|
||||
name: Tests
|
||||
name: PyTest
|
||||
on:
|
||||
# check on push/merge to main, PRs, and manual triggers
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -46,35 +24,9 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install -U -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
pip install -e .
|
||||
pip install -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest --ignore=tests/e2e/ tests/
|
||||
|
||||
e2e-test:
|
||||
name: E2E Tests
|
||||
runs-on: [self-hosted, gpu]
|
||||
timeout-minutes: 20
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
# cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 uninstall -y transformers accelerate
|
||||
pip3 install -U -e .[flash-attn]
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run e2e tests
|
||||
run: |
|
||||
pytest tests/e2e/
|
||||
pytest tests/
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -161,7 +161,3 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# WandB
|
||||
# wandb creates a folder to store logs for training runs
|
||||
wandb
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
[settings]
|
||||
profile=black
|
||||
known_third_party=wandb
|
||||
|
||||
@@ -8,9 +8,6 @@ ignore_missing_imports = True
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-axolotl.models.phi.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-flash_attn.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@@ -23,9 +20,6 @@ ignore_missing_imports = True
|
||||
[mypy-peft]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-wandb]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-bitsandbytes]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
||||
@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
too-many-boolean-expressions,
|
||||
|
||||
522
README.md
522
README.md
@@ -2,18 +2,6 @@
|
||||
|
||||
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
|
||||
|
||||
Features:
|
||||
- Train various Huggingface models such as llama, pythia, falcon, mpt
|
||||
- Supports fullfinetune, lora, qlora, relora, and gptq
|
||||
- Customize configurations using a simple yaml file or CLI overwrite
|
||||
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
|
||||
- Integrated with xformer, flash attention, rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb
|
||||
- And more!
|
||||
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
@@ -23,16 +11,13 @@ Features:
|
||||
- [Supported Features](#axolotl-supports)
|
||||
- [Quickstart](#quickstart-)
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Docker Installation](#environment)
|
||||
- [Conda/Pip venv Installation](#condapip-venv)
|
||||
- [LambdaLabs Installation](#lambdalabs)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
- [Config](#config)
|
||||
- [Train](#train)
|
||||
- [Training w/ Deepspeed](#training-with-deepspeed)
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Common Errors](#common-errors-)
|
||||
@@ -51,7 +36,7 @@ Features:
|
||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||
</p>
|
||||
<p>
|
||||
Go ahead and Axolotl questions!!
|
||||
Go ahead and axolotl questions!!
|
||||
</p>
|
||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||
@@ -65,16 +50,14 @@ Features:
|
||||
## Axolotl supports
|
||||
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
|----------|:----------|:-----|-------|------|-------------------|------------|---------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
@@ -85,29 +68,28 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--peft_model_dir="./lora-out"
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
--inference --lora_model_dir="./lora-out"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Environment
|
||||
|
||||
#### Docker
|
||||
- Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod
|
||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
@@ -115,23 +97,27 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
#### Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
- Conda/Pip venv
|
||||
1. Install python **3.9**
|
||||
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install Axolotl along with python dependencies
|
||||
3. Install python dependencies with ONE of the following:
|
||||
- Recommended, supports QLoRA, NO gptq/int4 support
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||
- gptq/int4 support, NO QLoRA
|
||||
```bash
|
||||
huggingface-cli login
|
||||
pip3 install -e .[gptq]
|
||||
```
|
||||
- same as above but not recommended
|
||||
```bash
|
||||
pip3 install -e .[gptq_triton]
|
||||
```
|
||||
Get the token at huggingface.co/settings/tokens
|
||||
|
||||
#### LambdaLabs
|
||||
- LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -163,10 +149,12 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -e . # change depend on needs
|
||||
pip3 install protobuf==3.20.3
|
||||
pip3 install -U --ignore-installed requests Pillow psutil scipy
|
||||
pip3 install -U requests
|
||||
pip3 install -U --ignore-installed psutil
|
||||
pip3 install -U scipy
|
||||
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
|
||||
```
|
||||
|
||||
5. Set path
|
||||
@@ -175,9 +163,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```
|
||||
</details>
|
||||
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
|
||||
### Dataset
|
||||
|
||||
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
||||
@@ -187,7 +172,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
||||
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -252,10 +237,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"article": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_v2`: in context question answering (alternate)
|
||||
```json
|
||||
{"context": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||
```json
|
||||
{"article": "...", "unanswerable_question": "..."}
|
||||
@@ -276,15 +257,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `metharme`: instruction, adds additional eos tokens
|
||||
```json
|
||||
{"prompt": "...", "generation": "..."}
|
||||
```
|
||||
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
|
||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -297,28 +274,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
For a dataset that is preprocessed for instruction purposes:
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
```json
|
||||
{"instruction": "...", "output": "..."}
|
||||
```
|
||||
Optionally, download some datasets, see [data/README.md](data/README.md)
|
||||
|
||||
You can use this example in your YAML config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: repo
|
||||
type:
|
||||
system_prompt: ""
|
||||
field_system: system
|
||||
format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
```
|
||||
|
||||
#### How to use your custom pretokenized dataset
|
||||
|
||||
- Do not pass a `type:`
|
||||
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
||||
|
||||
|
||||
### Config
|
||||
@@ -345,28 +305,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- path: EleutherAI/pile
|
||||
name: enron_emails
|
||||
type: completion # format from earlier
|
||||
field: text # Optional[str] default: text, field to use for completion data
|
||||
|
||||
# huggingface repo with multiple named configurations/subsets
|
||||
datasets:
|
||||
- path: bigcode/commitpackft
|
||||
name:
|
||||
- ruby
|
||||
- python
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: data.jsonl # or json
|
||||
ds_type: json # see other options below
|
||||
type: alpaca
|
||||
|
||||
# dataset with splits, but no train split
|
||||
dataset:
|
||||
- path: knowrohit07/know_sql
|
||||
type: context_qa.load_v2
|
||||
train_on_split: validation
|
||||
- path: json
|
||||
data_files: data.jsonl # or json
|
||||
type: alpaca # format from earlier
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -384,10 +328,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- lora
|
||||
```yaml
|
||||
adapter: lora # qlora or leave blank for full finetune
|
||||
peft_r: 8
|
||||
peft_alpha: 16
|
||||
peft_dropout: 0.05
|
||||
peft_target_modules:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
```
|
||||
@@ -397,15 +341,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
<summary>All yaml options</summary>
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# This can also be a relative path to a model on disk
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
base_model: ./llama-7b-hf
|
||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
# you can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
@@ -420,24 +364,18 @@ trust_remote_code:
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
# resize the model embeddings when new tokens are added to multiples of 32
|
||||
# this is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
# Used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
# whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
gptq_groupsize: 128 # group size
|
||||
gptq_model_v1: false # v1 or v2
|
||||
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# Use bitsandbytes 4 bit
|
||||
# use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
@@ -447,59 +385,28 @@ fp16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true # require >=ampere
|
||||
|
||||
# No AMP (automatic mixed precision)
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
data_files: # path to source data files
|
||||
shards: # number of shards to split data into
|
||||
name: # name of dataset configuration to load
|
||||
|
||||
# Custom user prompt
|
||||
- path: repo
|
||||
type:
|
||||
# The below are defaults. only set what's needed.
|
||||
system_prompt: ""
|
||||
system_format: "{system}"
|
||||
field_system: system
|
||||
field_instruction: instruction
|
||||
field_input: input
|
||||
field_output: output
|
||||
|
||||
# Customizable to be single line or multi-line
|
||||
# 'format' can include {input}
|
||||
format: |-
|
||||
User: {instruction} {input}
|
||||
Assistant:
|
||||
# 'no_input_format' cannot include {input}
|
||||
no_input_format: "{instruction} "
|
||||
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# Push prepared dataset to hub
|
||||
# push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
val_set_size: 0.04
|
||||
@@ -508,38 +415,29 @@ dataset_shard_num:
|
||||
# Index of shard to use for whole dataset
|
||||
dataset_shard_idx:
|
||||
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
# Max sequence length to concatenate training samples together up to
|
||||
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||
eval_sample_packing:
|
||||
# You can set these packing optimizations AFTER starting a training at least once.
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||
peft_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
peft_r: 8
|
||||
peft_alpha: 16
|
||||
peft_dropout: 0.05
|
||||
peft_target_modules:
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
# lora hyperparameters
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
@@ -547,123 +445,69 @@ peft_target_modules:
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
peft_target_linear: # if true, will target all linear layers
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
peft_modules_to_save:
|
||||
lora_target_linear: # if true, will target all linear layers
|
||||
lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
# Once you complete training, the model will be saved to the following directory.
|
||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||
lora_out_dir:
|
||||
peft_fan_in_fan_out: false
|
||||
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
|
||||
|
||||
# ReLoRA configuration
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id: # Set the name of your wandb run
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
# where to save the finished model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
torch_compile: # bool
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
# training hyperparameters
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
eval_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
save_steps: # leave empty to save at each epoch
|
||||
eval_steps:
|
||||
save_total_limit: # checkpoints saved at a time
|
||||
max_steps:
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# Group similarly sized data to minimize padding.
|
||||
# May be slower to start, as it must download and sort the entire dataset.
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
# specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
# for one_cycle optim
|
||||
lr_div_factor: # learning rate div factor
|
||||
|
||||
# For log_sweep optim
|
||||
# for log_sweep optim
|
||||
log_sweep_min_lr:
|
||||
log_sweep_max_lr:
|
||||
|
||||
# Specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
#
|
||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||
# in the examples/ for your model and fine-tuning use case.
|
||||
#
|
||||
# Valid values for 'optimizer' include:
|
||||
# - adamw_hf
|
||||
# - adamw_torch
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_apex_fused
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - sgd
|
||||
# - adagrad
|
||||
# - adamw_bnb_8bit
|
||||
# - lion_8bit
|
||||
# - lion_32bit
|
||||
# - paged_adamw_32bit
|
||||
# - paged_adamw_8bit
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
# specify optimizer
|
||||
optimizer:
|
||||
# Specify weight decay
|
||||
# specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
@@ -672,66 +516,55 @@ adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
noisy_embedding_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
# whether to bettertransformers
|
||||
flash_optimum:
|
||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
# Whether to use scaled-dot-product attention
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# LLaMA only
|
||||
# llama only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# Resume from a specific checkpoint dir
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
# Add or change special tokens.
|
||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
# add or change special tokens
|
||||
special_tokens:
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
|
||||
# Add extra tokens.
|
||||
# add extra tokens
|
||||
tokens:
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
# Deepspeed config path. e.g., deepspeed/zero3.json
|
||||
# Deepspeed config path
|
||||
deepspeed:
|
||||
|
||||
# Advanced DDP Arguments
|
||||
ddp_timeout:
|
||||
ddp_bucket_cap_mb:
|
||||
ddp_broadcast_buffers:
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
# Set padding for data collator to 'longest'
|
||||
collator_pad_to_longest:
|
||||
|
||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||
pretraining_dataset:
|
||||
|
||||
@@ -747,78 +580,18 @@ strict:
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary> Understanding of batch size and gradient accumulation steps </summary>
|
||||
<br/>
|
||||
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
|
||||
|
||||
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
|
||||
|
||||
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
|
||||
|
||||
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
|
||||
|
||||
**Example 1:**
|
||||
Micro batch size: 3
|
||||
Gradient accumulation steps: 2
|
||||
Number of GPUs: 3
|
||||
Total batch size = 3 * 2 * 3 = 18
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|----------------|----------------|----------------|
|
||||
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
|
||||
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (accumulate) | → (accumulate) | → (accumulate) |
|
||||
|----------------|----------------|----------------|
|
||||
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
|
||||
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
|
||||
```
|
||||
|
||||
**Example 2:**
|
||||
Micro batch size: 2
|
||||
Gradient accumulation steps: 1
|
||||
Number of GPUs: 3
|
||||
Total batch size = 2 * 1 * 3 = 6
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|-----------|-----------|-----------|
|
||||
| S1, S2 | S3, S4 | S5, S6 |
|
||||
| e1, e2 | e3, e4 | e5, e6 |
|
||||
|-----------|-----------|-----------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Train
|
||||
|
||||
Run
|
||||
```bash
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
accelerate launch scripts/finetune.py configs/your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
||||
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
|
||||
```
|
||||
|
||||
##### Config
|
||||
@@ -834,6 +607,11 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
- llama Deepspeed
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero3.json
|
||||
```
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
- wandb options
|
||||
@@ -846,58 +624,36 @@ wandb_run_id:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Training with Deepspeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
|
||||
- Pretrained LORA:
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
|
||||
--inference --lora_model_dir="./lora-output-dir"
|
||||
```
|
||||
- Full weights finetune:
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
|
||||
--inference --base_model="./completed-model"
|
||||
```
|
||||
- Full weights finetune w/ a prompt from a text file:
|
||||
```bash
|
||||
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
|
||||
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
|
||||
|
||||
### Merge LORA to base
|
||||
|
||||
Add below flag to train command above
|
||||
|
||||
```bash
|
||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
```
|
||||
|
||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
|
||||
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
|
||||
```
|
||||
|
||||
## Common Errors 🧰
|
||||
@@ -910,9 +666,7 @@ Please reduce any below
|
||||
- `gradient_accumulation_steps`
|
||||
- `sequence_len`
|
||||
|
||||
> `failed (exitcode: -9)`
|
||||
|
||||
Usually means your system has run out of system memory.
|
||||
> `failed (exitcode: -9)` usually means your system has run out of system memory.
|
||||
Similarly, you should consider reducing the same settings as when you run out of VRAM.
|
||||
Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.
|
||||
|
||||
@@ -928,10 +682,6 @@ Try to turn off xformers.
|
||||
|
||||
It's safe to ignore it.
|
||||
|
||||
> NCCL Timeouts during training
|
||||
|
||||
See the [NCCL](docs/nccl.md) guide.
|
||||
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
24
data/README.md
Normal file
24
data/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
## Download some datasets
|
||||
```shell
|
||||
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
|
||||
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
|
||||
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
|
||||
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
|
||||
```
|
||||
|
||||
## Convert the JSON data files to JSONL.
|
||||
|
||||
```shell
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||
```
|
||||
---
|
||||
|
||||
Using JSONL makes it easier to subset the data if you want a smaller training set, i.e get 2000 random examples.
|
||||
|
||||
```shell
|
||||
shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
|
||||
```
|
||||
1
data/raw/.gitignore
vendored
Normal file
1
data/raw/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
**
|
||||
@@ -1,41 +0,0 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 1,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"contiguous_gradients": true,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -35,8 +35,11 @@
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.95
|
||||
],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
@@ -45,11 +48,9 @@
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -9,11 +9,6 @@ services:
|
||||
- ~/.cache/huggingface/:/root/.cache/huggingface/
|
||||
# set environment variables
|
||||
environment:
|
||||
# Set environment variables
|
||||
- GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}
|
||||
- GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}
|
||||
- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
|
||||
- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
|
||||
- WANDB_API_KEY=${WANDB_API_KEY}
|
||||
deploy:
|
||||
resources:
|
||||
|
||||
@@ -5,29 +5,25 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main"
|
||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
RUN cd axolotl && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .[flash-attn]; \
|
||||
pip install -e .; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
RUN cd axolotl && \
|
||||
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
|
||||
@@ -13,14 +13,16 @@ ARG CUDA="118"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh
|
||||
|
||||
RUN conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
@@ -29,6 +31,26 @@ WORKDIR /workspace
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
|
||||
FROM base-builder AS flash-attn-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
||||
cd flash-attention && \
|
||||
git checkout v2.0.4 && \
|
||||
python3 setup.py bdist_wheel && \
|
||||
cd csrc/fused_dense_lib && \
|
||||
python3 setup.py bdist_wheel && \
|
||||
cd ../xentropy && \
|
||||
python3 setup.py bdist_wheel && \
|
||||
cd ../rotary && \
|
||||
python3 setup.py bdist_wheel && \
|
||||
cd ../layer_norm && \
|
||||
python3 setup.py bdist_wheel
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
@@ -37,15 +59,13 @@ WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
cd DeepSpeed && \
|
||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
|
||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel
|
||||
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
ARG MAX_JOBS="-1"
|
||||
ENV MAX_JOBS=$MAX_JOBS
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
@@ -57,6 +77,12 @@ FROM base-builder
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
# recompile apex
|
||||
RUN python3 -m pip uninstall -y apex
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
|
||||
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
@@ -64,8 +90,13 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
|
||||
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
||||
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
|
||||
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
|
||||
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
|
||||
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
|
||||
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
|
||||
|
||||
RUN pip3 install wheels/deepspeed-*.whl
|
||||
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install awscli && \
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
# Multi Node
|
||||
|
||||
You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below:
|
||||
|
||||
~/.cache/huggingface/accelerate/default_config.yaml
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines
|
||||
main_process_ip: 10.0.0.4 # Set to main machine's IP
|
||||
main_process_port: 5000
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 2 # Change to the number of machines
|
||||
num_processes: 4 # That's the total number of GPUs, (for example: if you have 2 machines with 4 GPU, put 8)
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
Configure your model to use FSDP with for example:
|
||||
```yaml
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_offload_params: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
## Machine configuration
|
||||
|
||||
On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility.
|
||||
|
||||
You will also need to have the same configuration file for your model on each machine.
|
||||
|
||||
On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
|
||||
|
||||
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
|
||||
@@ -1,51 +0,0 @@
|
||||
# Multipack
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
X represents a padding token
|
||||
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B ]
|
||||
C C C C C C C ]
|
||||
D D D D ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F ]
|
||||
[ G G G ]
|
||||
[ H H H H ]]
|
||||
|
||||
[[ I I I ]
|
||||
[ J J J ]
|
||||
[ K K K K K]
|
||||
[ L L L ]]
|
||||
```
|
||||
|
||||
after padding to longest input in each step
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B X X X X X X ]
|
||||
C C C C C C C X X X X ]
|
||||
D D D D X X X X X X X ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F X X X X ]
|
||||
[ G G G X X X X X ]
|
||||
[ H H H H X X X X ]]
|
||||
|
||||
[[ I I I X X ]
|
||||
[ J J J X X ]
|
||||
[ K K K K K ]
|
||||
[ L L L X X ]]
|
||||
```
|
||||
|
||||
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A B B B B B
|
||||
B C C C C C C C D D D D E E E E
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
46
docs/nccl.md
46
docs/nccl.md
@@ -1,46 +0,0 @@
|
||||
# NCCL
|
||||
|
||||
NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort:
|
||||
|
||||
```text
|
||||
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
|
||||
```
|
||||
|
||||
Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
|
||||
|
||||
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
|
||||
|
||||
```shell
|
||||
nvidia-smi nvlink --status
|
||||
```
|
||||
|
||||
To force NCCL to use NVLink, simply set this in the environment:
|
||||
|
||||
```shell
|
||||
export NCCL_P2P_LEVEL=NVL
|
||||
```
|
||||
|
||||
If NVLink is not available in your environment there are other options for ``NCCL_P2P_LEVEL`` in the table below:
|
||||
|
||||
| NCCL_P2P_LEVEL | Description |
|
||||
| -------------- | ----------- |
|
||||
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates vs to paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. |
|
||||
| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
|
||||
| PHB | P2P data transfers occur over the PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (ex PIX, NVL) |
|
||||
|
||||
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
|
||||
|
||||
```shell
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
|
||||
|
||||
```shell
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=ALL
|
||||
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
|
||||
```
|
||||
|
||||
Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.
|
||||
@@ -1,90 +0,0 @@
|
||||
base_model: cerebras/btlm-3b-8k-base
|
||||
base_model_config: cerebras/btlm-3b-8k-base
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: GPT2Tokenizer
|
||||
trust_remote_code: true
|
||||
tokenizer_use_fast: true
|
||||
tokenizer_legacy: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
hf_use_auth_token: true
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.01
|
||||
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
sample_packing: false
|
||||
sample_packing_eff_est:
|
||||
sample_packing_seq_len_multiplier:
|
||||
total_num_tokens:
|
||||
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
output_dir: btlm-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.000000001
|
||||
max_grad_norm: 1.0
|
||||
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
lr_quadratic_warmup: true
|
||||
learning_rate: 0.000085
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
|
||||
warmup_steps: 32
|
||||
eval_steps:
|
||||
save_steps:
|
||||
save_total_limit:
|
||||
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
fsdp:
|
||||
# - full_shard
|
||||
# - auto_wrap
|
||||
fsdp_config:
|
||||
# fsdp_state_dict_type: FULL_STATE_DICT
|
||||
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock
|
||||
@@ -7,10 +7,10 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 16
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,68 +0,0 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,68 +0,0 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,22 +0,0 @@
|
||||
# Overview
|
||||
|
||||
This is an example of CodeLLaMA configuration for 7b, 13b and 34b.
|
||||
|
||||
The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
|
||||
|
||||
The 13b variant will fit if you change these settings to these values:
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
|
||||
The 34b variant does not fit on 24GB of VRAM - you will need something with +40 gb VRAM that also supports flash attention v2 - A6000 or A100 are good choices.
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml
|
||||
|
||||
```
|
||||
or
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml
|
||||
|
||||
```
|
||||
@@ -3,7 +3,6 @@ base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
@@ -12,10 +11,10 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 16
|
||||
|
||||
@@ -6,7 +6,6 @@ base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
@@ -18,11 +17,11 @@ datasets:
|
||||
data_files:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
@@ -12,10 +11,10 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 64
|
||||
|
||||
@@ -7,10 +7,10 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
8
examples/gptq-lora-7b/README.md
Normal file
8
examples/gptq-lora-7b/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# LLaMa 7B using LoRA
|
||||
|
||||
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
|
||||
|
||||
```
|
||||
63
examples/gptq-lora-7b/config.yml
Normal file
63
examples/gptq-lora-7b/config.yml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
|
||||
base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code:
|
||||
load_in_8bit: true
|
||||
gptq: true
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora-int4
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-7b-lora-int4
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
fp16: true
|
||||
bf16: false
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 5
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
gradient_checkpointing: true
|
||||
gptq_groupsize: 128
|
||||
gptq_model_v1: false
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
tokens:
|
||||
pad_token: "[PAD]"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -6,10 +6,10 @@ load_in_8bit: false
|
||||
datasets:
|
||||
- path: openaccess-ai-collective/jeopardy
|
||||
type: jeopardy
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 512
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
base_model: TheBloke/Llama-2-7B-GPTQ
|
||||
base_model_config: TheBloke/Llama-2-7B-GPTQ
|
||||
is_llama_derived_model: false
|
||||
gptq: true
|
||||
gptq_disable_exllama: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
tokenizer_use_fast: true
|
||||
tokenizer_legacy: true
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
hf_use_auth_token: true
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
sequence_len: 4096
|
||||
sample_packing:
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- k_proj
|
||||
- o_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.00001
|
||||
max_grad_norm: 1.0
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
lr_quadratic_warmup: true
|
||||
learning_rate: 0.000017
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
fp16: false
|
||||
float16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
warmup_steps: 100
|
||||
eval_steps:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./ia3-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: ia3
|
||||
peft_model_dir:
|
||||
peft_target_modules:
|
||||
- k_proj
|
||||
- v_proj
|
||||
- down_proj
|
||||
peft_feedforward_modules:
|
||||
- down_proj
|
||||
peft_fan_in_fan_out: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 5
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,16 +11,15 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
@@ -56,8 +55,6 @@ flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,16 +11,15 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
@@ -58,7 +57,6 @@ flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./relora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
relora_steps: 150
|
||||
relora_warmup_steps: 10
|
||||
relora_cpu_offload: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps: 50
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,69 +0,0 @@
|
||||
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,12 +0,0 @@
|
||||
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks.
|
||||
|
||||
Fine Tune:
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
|
||||
|
||||
```
|
||||
|
||||
If you run into CUDA OOM, use deepspeed with config zero2.json:
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
|
||||
```
|
||||
@@ -1,62 +0,0 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
base_model_config: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000005
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,79 +0,0 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
base_model_config: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -6,10 +6,10 @@ load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -9,12 +9,12 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_model_dir:
|
||||
sequence_len: 256
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
@@ -29,11 +29,11 @@ wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000003
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
float16: true
|
||||
@@ -45,12 +45,12 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -9,12 +9,12 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_model_dir:
|
||||
sequence_len: 256
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.0
|
||||
@@ -33,9 +33,9 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
batch_size: 16
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -50,16 +50,16 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -9,12 +9,12 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
@@ -27,33 +27,33 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
batch_size: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
fp16: true
|
||||
tf32: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
# Phi
|
||||
|
||||
Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
|
||||
|
||||
# OR
|
||||
|
||||
python -m axolotl.cli.train examples/phi/phi-qlora.yml
|
||||
```
|
||||
@@ -1,75 +0,0 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: MixFormerSequentialForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: garage-bAInd/Open-Platypus
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./phi-sft-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_epsilon: 0.00001
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000003
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing:
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
resize_token_embeddings_to_32x: true
|
||||
special_tokens:
|
||||
bos_token: "<|endoftext|>"
|
||||
eos_token: "<|endoftext|>"
|
||||
unk_token: "<|endoftext|>"
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -1,75 +0,0 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: garage-bAInd/Open-Platypus
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./phi-sft-out
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false # not CURRENTLY compatible with LoRAs
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_r: 64
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_epsilon: 0.00001
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000003
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing:
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
resize_token_embeddings_to_32x: true
|
||||
special_tokens:
|
||||
bos_token: "<|endoftext|>"
|
||||
eos_token: "<|endoftext|>"
|
||||
unk_token: "<|endoftext|>"
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -10,10 +10,10 @@ device_map: auto
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 64
|
||||
@@ -47,3 +47,4 @@ local_rank:
|
||||
gradient_checkpointing: true
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
collator_pad_to_longest: true
|
||||
|
||||
@@ -4,10 +4,10 @@ load_in_8bit: true
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 512
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
@@ -28,8 +28,8 @@ num_epochs: 3
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
bf16: True
|
||||
tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
|
||||
@@ -7,10 +7,10 @@ load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
@@ -5,10 +5,10 @@ load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
@@ -16,11 +16,11 @@ datasets:
|
||||
data_files:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 8192
|
||||
max_packed_sequence_len:
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 370 KiB |
@@ -1,27 +1,20 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
torch==2.0.1
|
||||
auto-gptq
|
||||
packaging
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||
deepspeed
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
PyYAML==6.0
|
||||
datasets
|
||||
flash-attn>=2.3.0
|
||||
accelerate>=0.19.0
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.22
|
||||
xformers
|
||||
optimum
|
||||
hf_transfer
|
||||
colorama
|
||||
numba
|
||||
numpy>=1.24.4
|
||||
numpy==1.24.4
|
||||
# qlora things
|
||||
bert-score==0.3.13
|
||||
evaluate==0.4.0
|
||||
@@ -29,5 +22,3 @@ rouge-score==0.1.2
|
||||
scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.29
|
||||
|
||||
52
scripts/alpaca_json_to_jsonl.py
Normal file
52
scripts/alpaca_json_to_jsonl.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Module to convert json file to jsonl"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import fire
|
||||
|
||||
from axolotl.convert import (
|
||||
FileReader,
|
||||
FileWriter,
|
||||
JsonlSerializer,
|
||||
JsonParser,
|
||||
JsonToJsonlConverter,
|
||||
StdoutWriter,
|
||||
)
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
configure_logging()
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
|
||||
def main(
|
||||
file: Path,
|
||||
output: Optional[Path] = None,
|
||||
to_stdout: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Convert a json file to jsonl
|
||||
"""
|
||||
|
||||
file_reader = FileReader()
|
||||
writer: Union[StdoutWriter, FileWriter]
|
||||
if to_stdout or output is None:
|
||||
writer = StdoutWriter()
|
||||
else:
|
||||
writer = FileWriter(output)
|
||||
json_parser = JsonParser()
|
||||
jsonl_serializer = JsonlSerializer()
|
||||
|
||||
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
||||
|
||||
converter.convert(file, output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
@@ -1,54 +1,315 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
do_inference,
|
||||
do_merge_lora,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.cli.shard import shard
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
LOG = logging.getLogger("axolotl.scripts.finetune")
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
print_axolotl_text_art()
|
||||
LOG.warning(
|
||||
str(
|
||||
PendingDeprecationWarning(
|
||||
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
|
||||
)
|
||||
def print_axolotl_text_art():
|
||||
ascii_art = """
|
||||
dP dP dP
|
||||
88 88 88
|
||||
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||
"""
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to finish): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||
return instruction
|
||||
|
||||
|
||||
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
)
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if parsed_cli_args.inference:
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
elif parsed_cli_args.merge_lora:
|
||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
elif parsed_cli_args.shard:
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextStreamer(tokenizer)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
if not yaml_files:
|
||||
raise ValueError(
|
||||
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
||||
)
|
||||
|
||||
print("Choose a YAML file:")
|
||||
for idx, file in enumerate(yaml_files):
|
||||
print(f"{idx + 1}. {file}")
|
||||
|
||||
chosen_file = None
|
||||
while chosen_file is None:
|
||||
try:
|
||||
choice = int(input("Enter the number of your choice: "))
|
||||
if 1 <= choice <= len(yaml_files):
|
||||
chosen_file = yaml_files[choice - 1]
|
||||
else:
|
||||
print("Invalid choice. Please choose a number from the list.")
|
||||
except ValueError:
|
||||
print("Invalid input. Please enter a number.")
|
||||
|
||||
return chosen_file
|
||||
|
||||
|
||||
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
||||
return not any(el in list2 for el in list1)
|
||||
|
||||
|
||||
def train(
|
||||
config: Path = Path("configs/"),
|
||||
prepare_ds_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
print_axolotl_text_art()
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
# load the tokenizer first
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
if (
|
||||
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
|
||||
): # don't need to load dataset for these
|
||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||
|
||||
if cfg.debug or "debug" in kwargs:
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
||||
),
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
if prepare_ds_only:
|
||||
LOG.info("Finished preparing dataset. Exiting...")
|
||||
return
|
||||
|
||||
# Load the model and tokenizer
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, peft_config = load_model(cfg, tokenizer)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=torch.float16)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info("saving merged model")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
return
|
||||
|
||||
if cfg.inference:
|
||||
LOG.info("calling do_inference function")
|
||||
prompter: Optional[str] = "AlpacaPrompter"
|
||||
if "prompter" in kwargs:
|
||||
if kwargs["prompter"] == "None":
|
||||
prompter = None
|
||||
else:
|
||||
prompter = kwargs["prompter"]
|
||||
do_inference(cfg, model, tokenizer, prompter=prompter)
|
||||
return
|
||||
|
||||
if "shard" in kwargs:
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
return
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
|
||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||
LOG.info("Compiling torch model")
|
||||
model = torch.compile(model)
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
peft_config.save_pretrained(cfg.output_dir)
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
resume_from_checkpoint = sorted_paths[-1]
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
||||
)
|
||||
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(cfg.output_dir)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
trainer.save_model(cfg.output_dir)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
fire.Fire(train)
|
||||
|
||||
56
setup.py
56
setup.py
@@ -2,53 +2,31 @@
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def parse_requirements():
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
_dependency_links.append(url)
|
||||
elif (
|
||||
"flash-attn" not in line
|
||||
and "deepspeed" not in line
|
||||
and line
|
||||
and line[0] != "#"
|
||||
):
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
||||
if "torch==2.1.0" in _install_requires:
|
||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
||||
_install_requires.append(
|
||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
||||
)
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
install_requires, dependency_links = parse_requirements()
|
||||
|
||||
install_requires = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
# don't include peft yet until we check the int4
|
||||
# need to manually install peft for now...
|
||||
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
|
||||
reqs = [r for r in reqs if r and r[0] != "#"]
|
||||
for r in reqs:
|
||||
install_requires.append(r)
|
||||
|
||||
setup(
|
||||
name="axolotl",
|
||||
version="0.3.0",
|
||||
description="LLM Trainer",
|
||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||
version="0.1",
|
||||
description="You know you're going to axolotl questions",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages(),
|
||||
install_requires=install_requires,
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn>=2.3.0",
|
||||
"gptq": [
|
||||
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"deepspeed": [
|
||||
"gptq_triton": [
|
||||
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"extras": [
|
||||
"flash-attn",
|
||||
"deepspeed",
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
|
||||
def print_axolotl_text_art(suffix=None):
|
||||
font = "nancyj"
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
ascii_text += f" x {suffix}"
|
||||
ascii_art = text2art(" axolotl", font=font)
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||
return instruction
|
||||
|
||||
|
||||
def do_merge_lora(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=torch.float16)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
|
||||
def do_inference(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextStreamer(tokenizer)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
if not yaml_files:
|
||||
raise ValueError(
|
||||
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
||||
)
|
||||
|
||||
if len(yaml_files) == 1:
|
||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
||||
return yaml_files[0]
|
||||
|
||||
print("Choose a YAML file:")
|
||||
for idx, file in enumerate(yaml_files):
|
||||
print(f"{idx + 1}. {file}")
|
||||
|
||||
chosen_file = None
|
||||
while chosen_file is None:
|
||||
try:
|
||||
choice = int(input("Enter the number of your choice: "))
|
||||
if 1 <= choice <= len(yaml_files):
|
||||
chosen_file = yaml_files[choice - 1]
|
||||
else:
|
||||
print("Invalid choice. Please choose a number from the list.")
|
||||
except ValueError:
|
||||
print("Invalid input. Please enter a number.")
|
||||
|
||||
return chosen_file
|
||||
|
||||
|
||||
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
||||
return not any(el in list2 for el in list1)
|
||||
|
||||
|
||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
cfg.axolotl_config_path = config
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[
|
||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||
for _ in range(cli_args.debug_num_examples)
|
||||
]
|
||||
),
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
total_num_steps=total_num_steps,
|
||||
)
|
||||
|
||||
|
||||
def check_accelerate_default_config():
|
||||
if Path(config_args.default_yaml_config_file).exists():
|
||||
LOG.warning(
|
||||
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||
)
|
||||
|
||||
|
||||
def check_user_token():
|
||||
# Verify if token is valid
|
||||
api = HfApi()
|
||||
try:
|
||||
user_info = api.whoami()
|
||||
return bool(user_info)
|
||||
except LocalTokenNotFoundError:
|
||||
LOG.warning(
|
||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
@@ -1,28 +0,0 @@
|
||||
"""
|
||||
CLI to run inference on a trained model
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.inference = True
|
||||
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -1,27 +0,0 @@
|
||||
"""
|
||||
CLI to run merge a trained LoRA into a base model
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.merge_lora = True
|
||||
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
|
||||
|
||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -1,42 +0,0 @@
|
||||
"""
|
||||
CLI to shard a trained model into 10GiB chunks
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
|
||||
def shard(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
LOG.debug("Re-saving model w/ sharding")
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.shard = True
|
||||
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "--prepare_ds_only called without dataset_prepared_path set."
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
shared module for cli specific things
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.common.cli")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerCliArgs:
|
||||
"""
|
||||
dataclass representing the various non-training arguments
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prepare_ds_only: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
|
||||
return model, tokenizer
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
"""
|
||||
|
||||
@@ -38,15 +38,10 @@ class TokenizedPromptDataset(Dataset):
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 100
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -55,7 +50,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for processing the data.
|
||||
tokenizer (Tokenizer): The processor used for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""
|
||||
|
||||
@@ -1,43 +1,16 @@
|
||||
"""
|
||||
Common logging module for axolotl
|
||||
"""
|
||||
"""Logging configuration settings"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from logging import Formatter
|
||||
from logging.config import dictConfig
|
||||
from typing import Any, Dict
|
||||
|
||||
from colorama import Fore, Style, init
|
||||
|
||||
|
||||
class ColorfulFormatter(Formatter):
|
||||
"""
|
||||
Formatter to add coloring to log messages by log type
|
||||
"""
|
||||
|
||||
COLORS = {
|
||||
"WARNING": Fore.YELLOW,
|
||||
"ERROR": Fore.RED,
|
||||
"CRITICAL": Fore.RED + Style.BRIGHT,
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
record.rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
log_message = super().format(record)
|
||||
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
||||
|
||||
|
||||
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
"version": 1,
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
||||
},
|
||||
"colorful": {
|
||||
"()": ColorfulFormatter,
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
|
||||
},
|
||||
},
|
||||
"filters": {},
|
||||
"handlers": {
|
||||
@@ -47,25 +20,14 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
"filters": [],
|
||||
"stream": sys.stdout,
|
||||
},
|
||||
"color_console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "colorful",
|
||||
"filters": [],
|
||||
"stream": sys.stdout,
|
||||
},
|
||||
},
|
||||
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
|
||||
"loggers": {
|
||||
"axolotl": {
|
||||
"handlers": ["color_console"],
|
||||
"level": "DEBUG",
|
||||
"propagate": False,
|
||||
},
|
||||
"axolotl": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging():
|
||||
"""Configure with default logging"""
|
||||
init() # Initialize colorama
|
||||
dictConfig(DEFAULT_LOGGING_CONFIG)
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""
|
||||
MixFormers model architecture used for phi models
|
||||
"""
|
||||
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
||||
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
||||
@@ -1,63 +0,0 @@
|
||||
# pylint: skip-file
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MixFormerSequentialConfig(PretrainedConfig):
|
||||
"""MixFormer (sequential for DeepSpeed) configuration."""
|
||||
|
||||
model_type = "mixformer-sequential"
|
||||
|
||||
attribute_map = {
|
||||
"max_position_embeddings": "n_positions",
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
"input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
|
||||
"blocks": "architecture", # `blocks` key is for backward compatibility
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: Optional[int] = 50304,
|
||||
n_positions: Optional[int] = 2048,
|
||||
n_embd: Optional[int] = 1024,
|
||||
n_layer: Optional[int] = 20,
|
||||
n_inner: Optional[int] = None,
|
||||
n_head: Optional[int] = 16,
|
||||
rotary_dim: Optional[int] = 32,
|
||||
activation_function: Optional[str] = "gelu_new",
|
||||
embd_layer: Optional[str] = "default",
|
||||
architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
|
||||
embd_pdrop: Optional[float] = 0.0,
|
||||
resid_pdrop: Optional[float] = 0.0,
|
||||
layer_norm_epsilon: Optional[float] = 1e-5,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
pad_vocab_size_multiple: Optional[int] = 64,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.vocab_size = int(
|
||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_inner = n_inner
|
||||
self.n_head = n_head
|
||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||
self.activation_function = activation_function
|
||||
self.embd_layer = embd_layer
|
||||
self.architecture = architecture
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
@@ -1,930 +0,0 @@
|
||||
# pylint: skip-file
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# BSD 3-Clause License
|
||||
#
|
||||
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_qkvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceParams:
|
||||
"""Inference parameters that are passed to the main model in order
|
||||
to efficienly calculate and store the context during inference.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
max_sequence_len: int
|
||||
max_batch_size: int
|
||||
sequence_len_offset: int = 0
|
||||
batch_size_offset: int = 0
|
||||
key_value_memory_dict: dict = field(default_factory=dict)
|
||||
fused_ft_kernel: bool = False
|
||||
lengths_per_sample: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
"""Token embedding with dropout."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.wte(input_ids)
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base: Optional[int] = 10000,
|
||||
scale_base: Optional[float] = None,
|
||||
device: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if scale_base is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Generate and save the inverse frequency buffer (non-trainable)
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
self.scale_base = scale_base
|
||||
self.device = device
|
||||
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
scale = (
|
||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
||||
/ (1.4 * dim)
|
||||
if scale_base is not None
|
||||
else None
|
||||
)
|
||||
self.register_buffer("scale", scale)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _update_cos_sin_cache(
|
||||
self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
|
||||
) -> None:
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
seqlen = x.shape[1] + seqlen_offset
|
||||
|
||||
# Re-generate the inverse frequency buffer if it's not fp32
|
||||
# (for instance if model.half() was called)
|
||||
if self.inv_freq.dtype != "torch.float32":
|
||||
self.inv_freq = 1.0 / (
|
||||
self.base
|
||||
** (
|
||||
torch.arange(
|
||||
0, self.dim, 2, device=self.device, dtype=torch.float32
|
||||
)
|
||||
/ self.dim
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != x.device
|
||||
or self._cos_cached.dtype != x.dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
|
||||
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(
|
||||
t, self.inv_freq.to(device=t.device, dtype=torch.float32)
|
||||
)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
||||
else:
|
||||
power = (
|
||||
torch.arange(
|
||||
seqlen, dtype=self.scale.dtype, device=self.scale.device
|
||||
)
|
||||
- seqlen // 2
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(
|
||||
power, "s -> s 1"
|
||||
)
|
||||
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
||||
|
||||
def apply_rotary_emb_qkv(
|
||||
self,
|
||||
qkv: torch.FloatTensor,
|
||||
sin: torch.FloatTensor,
|
||||
cos: torch.FloatTensor,
|
||||
sin_k: Optional[torch.FloatTensor] = None,
|
||||
cos_k: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
_, seqlen, three, _, headdim = qkv.shape
|
||||
assert three == 3
|
||||
|
||||
rotary_seqlen, rotary_dim = cos.shape
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim
|
||||
assert seqlen <= rotary_seqlen
|
||||
|
||||
cos_k = cos if cos_k is None else cos_k
|
||||
sin_k = sin if sin_k is None else sin_k
|
||||
assert (
|
||||
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
)
|
||||
|
||||
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
||||
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
||||
|
||||
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
||||
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
||||
|
||||
# Splits the queries and keys in half
|
||||
q1, q2 = q_rot.chunk(2, dim=-1)
|
||||
k1, k2 = k_rot.chunk(2, dim=-1)
|
||||
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
|
||||
sin[:seqlen], "s d -> s 1 d"
|
||||
)
|
||||
|
||||
# Casts to fp32 are necessary to prevent fp16 overflow issues
|
||||
q1, q2, k1, k2, c, s = [
|
||||
t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
|
||||
]
|
||||
|
||||
# Computes the new keys and queries, recasting to original dtype
|
||||
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
||||
|
||||
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
||||
|
||||
return torch.cat(
|
||||
[
|
||||
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
||||
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
||||
qkv[:, :, 2:3, :, :],
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, qkv: torch.Tensor, seqlen_offset: int = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Perform the forward pass.
|
||||
|
||||
Args:
|
||||
qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
|
||||
seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
|
||||
|
||||
Returns:
|
||||
New `qkv` and the cached sinusoids.
|
||||
|
||||
"""
|
||||
|
||||
self._update_cos_sin_cache(qkv, seqlen_offset)
|
||||
|
||||
return self.apply_rotary_emb_qkv(
|
||||
qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
|
||||
)
|
||||
|
||||
|
||||
def _update_kv_cache(kv, inference_params, layer_idx):
|
||||
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
# Pre-allocate memory for key-values for inference.
|
||||
num_heads, head_dim = kv.shape[-2:]
|
||||
if layer_idx not in inference_params.key_value_memory_dict:
|
||||
kv_cache = torch.empty(
|
||||
inference_params.max_batch_size,
|
||||
inference_params.max_sequence_len,
|
||||
2,
|
||||
num_heads,
|
||||
head_dim,
|
||||
dtype=kv.dtype,
|
||||
device=kv.device,
|
||||
)
|
||||
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
||||
else:
|
||||
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
||||
|
||||
# Adjust key and value for inference
|
||||
batch_start = inference_params.batch_size_offset
|
||||
batch_end = batch_start + kv.shape[0]
|
||||
sequence_start = inference_params.sequence_len_offset
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
assert batch_end <= (
|
||||
kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa
|
||||
)
|
||||
assert sequence_end <= (
|
||||
kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa
|
||||
)
|
||||
|
||||
assert kv_cache is not None
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
return kv
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Multi-Layer Perceptron.
|
||||
|
||||
Reference:
|
||||
Attention Is All You Need.
|
||||
https://arxiv.org/pdf/1706.03762.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
n_inner: Optional[int] = None,
|
||||
act_fn: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
act_fn = config.activation_function if act_fn is None else act_fn
|
||||
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
||||
|
||||
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
||||
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
||||
|
||||
self.fc1 = nn.Linear(config.n_embd, n_inner)
|
||||
self.fc2 = nn.Linear(n_inner, config.n_embd)
|
||||
self.act = ACT2FN[act_fn]
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
old_keys = [
|
||||
prefix + "fc_in.weight",
|
||||
prefix + "fc_out.weight",
|
||||
prefix + "fc_in.bias",
|
||||
prefix + "fc_out.bias",
|
||||
]
|
||||
new_keys = [
|
||||
prefix + "fc1.weight",
|
||||
prefix + "fc2.weight",
|
||||
prefix + "fc1.bias",
|
||||
prefix + "fc2.bias",
|
||||
]
|
||||
|
||||
if all(k in state_dict for k in old_keys) and not all(
|
||||
k in state_dict for k in new_keys
|
||||
):
|
||||
# Older version of `MLP` saved with different key names.
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
return super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMLP(nn.Module):
|
||||
"""Fused Multi-Layer Perceptron from `flash-attn`.
|
||||
|
||||
Reference:
|
||||
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
n_inner: Optional[int] = None,
|
||||
act_fn: Optional[str] = None,
|
||||
raise_on_missing: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
act_fn = config.activation_function if act_fn is None else act_fn
|
||||
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
||||
|
||||
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
||||
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
||||
|
||||
gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa
|
||||
activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa
|
||||
|
||||
self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return self.mlp(hidden_states)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
||||
super().__init__()
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(
|
||||
self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
|
||||
):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
|
||||
causal: if passed, will override self.causal
|
||||
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
||||
False means to mask out. (B, S)
|
||||
"""
|
||||
causal = self.causal if causal is None else causal
|
||||
if cu_seqlens is not None:
|
||||
return flash_attn_varlen_qkvpacked_func(
|
||||
qkv.squeeze(0),
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=self.drop.p,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
else:
|
||||
return flash_attn_qkvpacked_func(
|
||||
qkv,
|
||||
dropout_p=self.drop.p,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
||||
super().__init__()
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(self, q, kv, causal=None, key_padding_mask=None):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
q: The tensor containing the query. (B, Sq, H, D)
|
||||
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
|
||||
causal: if passed, will override self.causal
|
||||
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
||||
False means to mask out. (B, Sk)
|
||||
"""
|
||||
causal = self.causal if causal is None else causal
|
||||
return flash_attn_kvpacked_func(
|
||||
q,
|
||||
kv,
|
||||
dropout_p=self.drop.p,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
|
||||
def find_mha_dims(
|
||||
config: PretrainedConfig,
|
||||
n_head: Optional[int] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""Validate and return the number of heads and head dimension for multi-head attention.
|
||||
|
||||
Args:
|
||||
config: Model configuration.
|
||||
n_head: Number of heads.
|
||||
head_dim: Head dimension.
|
||||
|
||||
Returns:
|
||||
Number of heads and head dimension.
|
||||
|
||||
"""
|
||||
|
||||
assert all(
|
||||
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
||||
), "`config` must have `n_embd` and `n_head` attributes."
|
||||
|
||||
if head_dim is None:
|
||||
assert (
|
||||
config.n_embd % config.n_head == 0
|
||||
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
|
||||
|
||||
if n_head is None and head_dim is None:
|
||||
head_dim = config.n_embd // config.n_head
|
||||
n_head = config.n_head
|
||||
elif n_head is None or head_dim is None:
|
||||
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
||||
|
||||
return n_head, head_dim
|
||||
|
||||
|
||||
class MHA(nn.Module):
|
||||
"""Multi-head attention layer.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
rotary_dim: Optional[int] = None,
|
||||
n_head: Optional[int] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
bias: Optional[bool] = True,
|
||||
dropout: Optional[float] = 0.0,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: Optional[bool] = True,
|
||||
layer_idx: Optional[int] = None,
|
||||
rotary_emb_scale_base: Optional[float] = None,
|
||||
return_residual: Optional[bool] = False,
|
||||
checkpointing: Optional[bool] = False,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
fused_dense: Optional[bool] = True,
|
||||
flash_attn: Optional[bool] = True,
|
||||
cutlass_attn: Optional[bool] = False,
|
||||
flash_rotary: Optional[bool] = True,
|
||||
raise_on_missing: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
n_head, head_dim = find_mha_dims(config, n_head, head_dim)
|
||||
|
||||
self.hidden_size = config.n_embd
|
||||
self.n_head = n_head
|
||||
self.head_dim = head_dim
|
||||
self.op_size = n_head * head_dim
|
||||
|
||||
self.causal = causal
|
||||
self.layer_idx = layer_idx
|
||||
self.rotary_emb_dim = (
|
||||
rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
||||
)
|
||||
self.fused_dense = fused_dense
|
||||
self.flash_attn = flash_attn
|
||||
self.cutlass_attn = cutlass_attn
|
||||
self.flash_rotary = flash_rotary
|
||||
self.return_residual = return_residual
|
||||
self.checkpointing = checkpointing
|
||||
|
||||
if self.rotary_emb_dim > 0:
|
||||
rotary_kwargs = {"device": device}
|
||||
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
||||
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
||||
|
||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
|
||||
else:
|
||||
pass
|
||||
|
||||
self.Wqkv = nn.Linear(
|
||||
self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.out_proj = nn.Linear(
|
||||
self.op_size, self.hidden_size, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self.inner_attn = SelfAttention(
|
||||
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
||||
)
|
||||
self.inner_cross_attn = CrossAttention(
|
||||
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
||||
)
|
||||
|
||||
def _update_kv_cache(
|
||||
self, kv: torch.FloatTensor, inference_params: InferenceParams
|
||||
) -> None:
|
||||
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
assert (
|
||||
self.layer_idx is not None
|
||||
), "Generation requires layer_idx in the constructor"
|
||||
|
||||
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.FloatTensor,
|
||||
x_kv: Optional[torch.FloatTensor] = None,
|
||||
key_padding_mask: Optional[torch.BoolTensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
mixer_subset: Optional[torch.LongTensor] = None,
|
||||
past_cache: Optional[InferenceParams] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Perform the forward pass.
|
||||
|
||||
Args:
|
||||
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
||||
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
||||
is the is the sum of the sequence lengths in the batch.
|
||||
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
|
||||
key_padding_mask: boolean mask, True means to keep, False means to mask out.
|
||||
(batch, seqlen). Only applicable when not using FlashAttention.
|
||||
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into x. Only applicable when using
|
||||
FlashAttention.
|
||||
max_seqlen: int. Maximum sequence length in the batch.
|
||||
mixer_subset: for cross-attention only. If not None, will take a subset of x
|
||||
before applying the query projection. Useful for e.g., ViT where we only care
|
||||
about the CLS token in the last layer.
|
||||
past_cache: For generation only.
|
||||
|
||||
Returns:
|
||||
(batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
|
||||
else (total, hidden_dim) where total is the is the sum of the sequence lengths
|
||||
in the batch.
|
||||
|
||||
"""
|
||||
|
||||
if cu_seqlens is not None:
|
||||
assert max_seqlen is not None
|
||||
assert key_padding_mask is None
|
||||
assert self.flash_attn
|
||||
# assert self.rotary_emb_dim == 0
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert cu_seqlens is None
|
||||
assert max_seqlen is None
|
||||
assert not self.flash_attn
|
||||
|
||||
if past_cache is not None:
|
||||
assert key_padding_mask is None
|
||||
assert cu_seqlens is None and max_seqlen is None
|
||||
|
||||
attn_kwargs = {"key_padding_mask": key_padding_mask}
|
||||
|
||||
assert x_kv is None and mixer_subset is None
|
||||
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = rearrange(
|
||||
qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
|
||||
)
|
||||
|
||||
if past_cache is None:
|
||||
if self.rotary_emb_dim > 0:
|
||||
qkv = self.rotary_emb(qkv)
|
||||
context = self.inner_attn(
|
||||
qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if self.rotary_emb_dim > 0:
|
||||
qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
|
||||
q = qkv[:, :, 0]
|
||||
kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
|
||||
# If we're processing the prompt, causal=None (use self.causal).
|
||||
# If we're decoding, then causal=False.
|
||||
causal = None if past_cache.sequence_len_offset == 0 else False
|
||||
context = self.inner_cross_attn(q, kv, causal=causal)
|
||||
|
||||
out = rearrange(context, "... h d -> ... (h d)")
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out if not self.return_residual else (out, x)
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
"""Parallel block.
|
||||
|
||||
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
mixer: Optional[Dict[str, Any]] = None,
|
||||
mlp: Optional[Dict[str, Any]] = None,
|
||||
block_idx: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.block_idx = block_idx
|
||||
|
||||
self.mixer = MHA(config, layer_idx=block_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
past_cache: Optional[torch.FloatTensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln(hidden_states)
|
||||
|
||||
attn_outputs = self.mixer(
|
||||
hidden_states,
|
||||
past_cache=past_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
if isinstance(attn_outputs, tuple):
|
||||
attn_outputs = attn_outputs[0]
|
||||
|
||||
attn_outputs = self.resid_dropout(attn_outputs)
|
||||
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
||||
|
||||
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CausalLMHead(nn.Module):
|
||||
"""Causal Language Modeling head.
|
||||
|
||||
Reference:
|
||||
Improving Language Understanding by Generative Pre-Training.
|
||||
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.linear = nn.Linear(config.n_embd, config.vocab_size)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.ln(hidden_states)
|
||||
logits = self.linear(hidden_states).to(torch.float32)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class CausalLMLoss(nn.Module):
|
||||
"""Causal Language Modeling loss.
|
||||
|
||||
Reference:
|
||||
Improving Language Understanding by Generative Pre-Training.
|
||||
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, shift_labels: Optional[bool] = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.shift_labels = shift_labels
|
||||
self.loss_fct = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(
|
||||
self, logits: torch.FloatTensor, labels: torch.LongTensor
|
||||
) -> torch.FloatTensor:
|
||||
if self.shift_labels:
|
||||
logits = logits[..., :-1, :].contiguous()
|
||||
labels = labels[..., 1:].contiguous()
|
||||
|
||||
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
||||
"""MixFormer (sequential for DeepSpeed) pre-trained model."""
|
||||
|
||||
config_class = MixFormerSequentialConfig
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs) -> None:
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
if "use_cache" in kwargs and not kwargs["use_cache"]:
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
if past_key_values is None or not (
|
||||
isinstance(past_key_values, InferenceParams)
|
||||
):
|
||||
past_key_values = InferenceParams(
|
||||
max_batch_size=input_ids.shape[0],
|
||||
max_sequence_len=self.config.n_positions,
|
||||
sequence_len_offset=0,
|
||||
batch_size_offset=0,
|
||||
fused_ft_kernel=False,
|
||||
key_value_memory_dict={},
|
||||
)
|
||||
else:
|
||||
# assume past_key_values has cached all but last token in input_ids
|
||||
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
|
||||
|
||||
|
||||
class PackedSequential(nn.Sequential):
|
||||
def forward(
|
||||
self,
|
||||
input,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
):
|
||||
for module in self:
|
||||
sig = inspect.signature(module.forward)
|
||||
if "cu_seqlens" in sig.parameters:
|
||||
input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
|
||||
else:
|
||||
input = module(input)
|
||||
return input
|
||||
|
||||
|
||||
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
||||
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
|
||||
|
||||
_keys_to_ignore_on_load_missing = [""]
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
|
||||
]
|
||||
_no_split_modules = ["ParallelBlock"]
|
||||
|
||||
def __init__(self, config: MixFormerSequentialConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
modules = [Embedding(config)]
|
||||
block_config = config.architecture
|
||||
|
||||
if not isinstance(block_config, list):
|
||||
block_config = [block_config for _ in range(config.n_layer)]
|
||||
|
||||
if config.n_layer != len(block_config):
|
||||
config.n_layer = len(block_config)
|
||||
|
||||
for block_idx, block in enumerate(block_config):
|
||||
# `block_cls` with `legacy` value is for backward compatibility
|
||||
# `path` key is for backward compatibility
|
||||
block = copy.deepcopy(block) or {"block_cls": "parallel"}
|
||||
block.pop("path", None) or block.pop("block_cls", None)
|
||||
|
||||
block["block_idx"] = block_idx
|
||||
modules.append(ParallelBlock(config, **block))
|
||||
|
||||
modules.append(CausalLMHead(config))
|
||||
|
||||
self.layers = PackedSequential(*modules)
|
||||
self.loss = CausalLMLoss()
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.layers[0].wte
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
||||
self.layers[0].wte = new_embeddings
|
||||
|
||||
def get_output_embeddings(self) -> nn.Linear:
|
||||
return self.layers[-1].linear
|
||||
|
||||
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
||||
self.layers[-1].linear = new_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
cu_seqlens: Optional[torch.LongTensor] = None
|
||||
max_seqlen: Optional[int] = None
|
||||
if position_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if not past_key_values:
|
||||
lm_logits = self.layers(
|
||||
input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
||||
)
|
||||
else:
|
||||
hidden_layer = self.layers[0](input_ids)
|
||||
for module in self.layers[1:-1]:
|
||||
hidden_layer = module(
|
||||
hidden_layer,
|
||||
past_cache=past_key_values,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
lm_logits = self.layers[-1](hidden_layer)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss(lm_logits, labels)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss, logits=lm_logits, past_key_values=past_key_values
|
||||
)
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
Flash attention monkey patch for cerebras btlm model
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
||||
# this is a wonky hack to get the remotely loaded module
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_btlm to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_btlm", ".modeling_btlm"
|
||||
)
|
||||
modeling_btlm = importlib.import_module(module_name)
|
||||
modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
|
||||
flashattn_attn
|
||||
)
|
||||
|
||||
|
||||
def flashattn_attn(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
value: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
softmax_scale = (
|
||||
1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
|
||||
)
|
||||
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
|
||||
# Perform Flash attention
|
||||
attn_output = flash_attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p=0.0, # Assuming you have this attribute
|
||||
softmax_scale=softmax_scale, # Set this if you have specific scaling in mind
|
||||
causal=not self.is_cross_attention, # Assuming you have this attribute
|
||||
return_attn_probs=False, # Set this based on your needs
|
||||
)
|
||||
|
||||
# Optional: Apply head mask if it's not None
|
||||
if head_mask is not None:
|
||||
attn_output *= head_mask
|
||||
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
return attn_output, None # We don't have explicit attn_weights in Flash attention
|
||||
@@ -1,174 +0,0 @@
|
||||
"""
|
||||
monkeypatch to add a get_turns method
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
ret = ""
|
||||
for role, msg in self.get_turns():
|
||||
ret += role + msg
|
||||
return ret
|
||||
|
||||
|
||||
def get_turns( # pylint: disable=too-many-return-statements
|
||||
self,
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message + seps[i % 2]
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ": ", "" # must be end with a space
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role, message + self.sep
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role, message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.RWKV:
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||
"\n\n", "\n"
|
||||
) + "\n\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.system_message:
|
||||
yield "", system_prompt
|
||||
else:
|
||||
yield "", "[INST] "
|
||||
for i, (role, message) in enumerate(self.messages[1:]):
|
||||
if message:
|
||||
yield role + " ", message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||
if system_prompt:
|
||||
yield "", system_prompt + self.sep
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
yield f"{role}:", f"{message}{self.sep}"
|
||||
else:
|
||||
yield f"{role}:", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATML:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep + "\n"
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<s>" if i % 2 == 0 else ""
|
||||
if message:
|
||||
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
suffix = "\n\n" if i % 2 == 1 else ""
|
||||
yield role + ":\n", message + seps[i % 2] + suffix
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", "<s>" + message + "</s>"
|
||||
else:
|
||||
yield role + ": " + "<s>", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ROBIN:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ":\n", message + self.sep
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||
if self.system_message:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
|
||||
def add_get_turns_to_conversation():
|
||||
import fastchat.conversation
|
||||
|
||||
fastchat.conversation.Conversation.get_turns = get_turns
|
||||
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||
@@ -2,9 +2,7 @@
|
||||
|
||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -35,14 +33,7 @@ except ImportError:
|
||||
)
|
||||
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
):
|
||||
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
@@ -53,38 +44,6 @@ def replace_llama_attn_with_flash_attn(
|
||||
llama_model_forward
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
@@ -107,7 +66,6 @@ def flashattn_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
@@ -116,8 +74,6 @@ def flashattn_forward(
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
@@ -153,13 +109,6 @@ def flashattn_forward(
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
if query_states.dtype == torch.float32:
|
||||
query_states = query_states.to(dtype=original_dtype)
|
||||
if key_states.dtype == torch.float32:
|
||||
key_states = key_states.to(dtype=original_dtype)
|
||||
if value_states.dtype == torch.float32:
|
||||
value_states = value_states.to(dtype=original_dtype)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
@@ -209,9 +158,9 @@ def flashattn_forward(
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
is_causal = past_key_value is not None
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
if cu_seqlens is not None and max_seqlen is not None:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
@@ -220,7 +169,7 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -279,8 +228,6 @@ def flashattn_forward(
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
@@ -318,10 +265,6 @@ def flashattn_forward(
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
# handle conversion back for IA3
|
||||
if attn_output.dtype == torch.float32:
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
@@ -498,13 +441,6 @@ def llama_model_forward(
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
@@ -515,7 +451,6 @@ def llama_model_forward(
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
@@ -540,9 +475,7 @@ def llama_model_forward(
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(
|
||||
*inputs,
|
||||
)
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
@@ -551,10 +484,9 @@ def llama_model_forward(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
None,
|
||||
output_attentions,
|
||||
None,
|
||||
padding_mask,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
@@ -566,17 +498,12 @@ def llama_model_forward(
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
# handle conversion back for IA3
|
||||
if hidden_states.dtype == torch.float32:
|
||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
@@ -617,7 +544,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
@@ -650,7 +576,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
||||
)
|
||||
@@ -1,640 +0,0 @@
|
||||
"""Flash attention monkey patch for mistral model"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
|
||||
def replace_mistral_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
):
|
||||
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||
mistral_model_forward
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self: OriginalMistralAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
hasattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if self.training:
|
||||
# during training q,k,v always have same seqlen
|
||||
assert key_states.shape == query_states.shape
|
||||
is_causal = True
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
qkvpacked=True,
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if attention_mask is None or attention_mask.all().item():
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
_,
|
||||
_,
|
||||
output_pad_fn,
|
||||
) = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
attn_output = output
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||
def generate_qkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
kvpacked=False,
|
||||
qkvpacked=False,
|
||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, d)
|
||||
k: (batch_size, seqlen_k, nheads_k, d)
|
||||
v: (batch_size, seqlen_k, nheads_k, d)
|
||||
query_padding_mask: (batch_size, seqlen), bool
|
||||
key_padding_mask: (batch_size, seqlen), bool
|
||||
"""
|
||||
assert not (kvpacked and qkvpacked)
|
||||
batch_size, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
|
||||
if query_padding_mask is not None:
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||
q, query_padding_mask
|
||||
)
|
||||
|
||||
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q_unpad.device,
|
||||
)
|
||||
max_seqlen_q = seqlen_q
|
||||
|
||||
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||
else:
|
||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=k_unpad.device,
|
||||
)
|
||||
max_seqlen_k = seqlen_k
|
||||
|
||||
if qkvpacked:
|
||||
assert nheads == nheads_k
|
||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||
|
||||
if kvpacked:
|
||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
return (
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
return (
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
|
||||
def mistral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
transformers.logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
"""
|
||||
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.mistral.modeling_mistral
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
||||
)
|
||||
@@ -1,393 +0,0 @@
|
||||
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl.relora")
|
||||
|
||||
|
||||
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
param_state = optimizer.state[param]
|
||||
for key in param_state:
|
||||
if "qmap" in key:
|
||||
continue
|
||||
|
||||
if key == "step" and isinstance(param_state[key], int):
|
||||
param_state[key] = 0
|
||||
else:
|
||||
param_state[key] = torch.zeros_like(param_state[key])
|
||||
|
||||
|
||||
class ReLoRACallback(TrainerCallback):
|
||||
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
|
||||
|
||||
def __init__(self, cfg: DictDefault):
|
||||
self.relora_steps = cfg.relora_steps
|
||||
self.cpu_offload = cfg.relora_cpu_offload
|
||||
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
|
||||
self.last_full_model = cfg.base_model
|
||||
self.resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
|
||||
if not os.path.exists(self.last_full_model):
|
||||
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
|
||||
|
||||
assert os.path.exists(
|
||||
self.last_full_model
|
||||
), "for ReLORA base_model must be a local path"
|
||||
|
||||
self.num_lora_restarts = 0
|
||||
self.need_full_save = False
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
_state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
**_kwargs,
|
||||
):
|
||||
if self.resume_from_checkpoint:
|
||||
weight_path = os.path.join(self.resume_from_checkpoint, "relora")
|
||||
if not os.path.exists(weight_path):
|
||||
LOG.warning(
|
||||
"Resuming ReLoRA from checkpoint, but no full-weight save found"
|
||||
)
|
||||
else:
|
||||
LOG.info(f"Loading adjusted base weights from {weight_path}")
|
||||
load_weight_checkpoint(model, weight_path)
|
||||
return control
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
**_kwargs,
|
||||
):
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
"relora",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
merge_and_save(
|
||||
model,
|
||||
self.last_full_model,
|
||||
checkpoint_folder,
|
||||
reinit=True,
|
||||
quantized=self.quantized,
|
||||
actually_save=is_main_process(),
|
||||
cpu_offload=self.cpu_offload,
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
if self.quantized:
|
||||
self.last_full_model = checkpoint_folder
|
||||
self.num_lora_restarts += 1
|
||||
|
||||
return control
|
||||
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
**_kwargs,
|
||||
):
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
|
||||
)
|
||||
if (
|
||||
state.global_step >= self.relora_steps
|
||||
and state.global_step % self.relora_steps != 0
|
||||
):
|
||||
if self.quantized:
|
||||
if is_main_process() and self.last_full_model != checkpoint_folder:
|
||||
# ensure the latest full parameter save is in the latest checkpoint
|
||||
# folder, so that automatic pruning of checkpoints does not remove it
|
||||
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
||||
os.makedirs(checkpoint_folder, exist_ok=True)
|
||||
chunks = glob.glob(
|
||||
f"{self.last_full_model}/model*.safetensors"
|
||||
) + glob.glob(f"{self.last_full_model}/model*.index.json")
|
||||
for path in chunks:
|
||||
new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
|
||||
try:
|
||||
os.symlink(new_path, path)
|
||||
except OSError:
|
||||
# probably on windows without permission to symlink
|
||||
pass
|
||||
|
||||
self.last_full_model = checkpoint_folder
|
||||
else:
|
||||
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
|
||||
|
||||
return control
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
_state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs: Dict[str, float],
|
||||
**_kwargs,
|
||||
):
|
||||
logs["num_lora_restarts"] = self.num_lora_restarts
|
||||
return control
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
_state: TrainerState,
|
||||
control: TrainerControl,
|
||||
model: peft.LoraModel,
|
||||
**_kwargs,
|
||||
):
|
||||
if self.quantized:
|
||||
# perform final merge and save
|
||||
with torch.no_grad():
|
||||
merge_and_save(
|
||||
model,
|
||||
self.last_full_model,
|
||||
args.output_dir,
|
||||
reinit=False,
|
||||
quantized=self.quantized,
|
||||
actually_save=is_main_process(),
|
||||
cpu_offload=self.cpu_offload,
|
||||
)
|
||||
# no need to save if unquantized, as finetune.py will call merge_and_unload()
|
||||
return control
|
||||
|
||||
|
||||
class ReLoRAScheduler(LRScheduler):
|
||||
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
inner_schedule: LRScheduler,
|
||||
relora_steps: int,
|
||||
warmup_steps: int,
|
||||
min_lr_scale: float = 0.001,
|
||||
) -> None:
|
||||
self.inner_schedule = inner_schedule
|
||||
self.relora_steps = relora_steps
|
||||
self.warmup_steps = warmup_steps
|
||||
self.min_lr_scale = min_lr_scale
|
||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||
|
||||
def get_lr(self) -> float:
|
||||
self.inner_schedule.last_epoch = self.last_epoch
|
||||
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
if step < self.relora_steps:
|
||||
scale = 1
|
||||
else:
|
||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||
|
||||
if isinstance(original, Sequence):
|
||||
return [lr * scale for lr in original]
|
||||
return original * scale
|
||||
|
||||
|
||||
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
|
||||
model_name = "model.safetensors"
|
||||
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||
str(Path(path) / f"{model_name}.index.json")
|
||||
):
|
||||
model_name = "pytorch_model.bin"
|
||||
|
||||
index_path = str(Path(path) / f"{model_name}.index.json")
|
||||
if os.path.exists(index_path):
|
||||
with open(index_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
return data["weight_map"]
|
||||
return {(module_name + ".weight"): model_name for module_name in module_names}
|
||||
|
||||
|
||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
|
||||
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
|
||||
adapter = layer.active_adapter
|
||||
return (
|
||||
peft.utils.transpose(
|
||||
layer.lora_B[adapter].weight.detach().to(device)
|
||||
@ layer.lora_A[adapter].weight.detach().to(device),
|
||||
getattr(layer, "fan_in_fan_out", False),
|
||||
)
|
||||
* layer.scaling[adapter]
|
||||
)
|
||||
|
||||
return layer.get_delta_weight().to(device)
|
||||
|
||||
|
||||
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
|
||||
modules: Dict[str, peft.tuners.lora.LoraLayer] = {}
|
||||
|
||||
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
||||
for key in key_list:
|
||||
try:
|
||||
# pylint: disable=protected-access
|
||||
_parent, target, _target_name = peft.utils._get_submodules(model.model, key)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||
modules[key] = target
|
||||
|
||||
return modules
|
||||
|
||||
|
||||
def update_weights(
|
||||
target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
|
||||
):
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
|
||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||
# This could be faster, but the quantization of Linear4bit weights occurs
|
||||
# when the module is moved from cpu to gpu. Without meddling *too* deeply in
|
||||
# PEFT's innards or maintaining a duplicate of that codepath, this is good
|
||||
# enough for now.
|
||||
target.weight.quant_state = None
|
||||
target.weight.data = new_weight.cpu()
|
||||
target.to(device)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
|
||||
else:
|
||||
target.weight.data = new_weight.to(device)
|
||||
|
||||
|
||||
def merge_and_save(
|
||||
model: peft.LoraModel,
|
||||
model_src: str,
|
||||
model_dst: str,
|
||||
reinit: bool = False,
|
||||
quantized: bool = False,
|
||||
cpu_offload: bool = False,
|
||||
actually_save: bool = True,
|
||||
):
|
||||
modules = find_lora_modules(model)
|
||||
|
||||
if not quantized:
|
||||
for module_name, target in modules.items():
|
||||
update = target.get_delta_weight(target.active_adapter).detach()
|
||||
target.weight.data += update
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
return
|
||||
|
||||
os.makedirs(model_dst, exist_ok=True)
|
||||
shard_paths = sharded_paths(model_src, modules.keys())
|
||||
out_shard_paths = {}
|
||||
|
||||
unique_shards = list(set(shard_paths.values()))
|
||||
for shard_path in unique_shards:
|
||||
out_tensors = {}
|
||||
if shard_path.endswith(".safetensors"):
|
||||
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
||||
else:
|
||||
in_tensors = torch.load(Path(model_src) / shard_path)
|
||||
if "state_dict" in in_tensors:
|
||||
in_tensors = in_tensors["state_dict"]
|
||||
|
||||
for module_name, target in modules.items():
|
||||
key = module_name + ".weight"
|
||||
if key not in shard_paths or shard_paths[key] != shard_path:
|
||||
continue
|
||||
|
||||
orig_weight = in_tensors[key]
|
||||
old_dev = target.weight.device
|
||||
math_dev = "cpu" if cpu_offload else old_dev
|
||||
|
||||
delta_weight = lora_delta_weight(target, math_dev)
|
||||
new_weight = orig_weight.to(math_dev) + delta_weight
|
||||
del delta_weight
|
||||
|
||||
if actually_save:
|
||||
out_tensors[key] = new_weight.half().cpu()
|
||||
|
||||
update_weights(target, new_weight, reinit=reinit, device=old_dev)
|
||||
|
||||
if actually_save:
|
||||
out_shard_name = shard_path
|
||||
if out_shard_name.startswith("pytorch_model"):
|
||||
out_shard_name = (
|
||||
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
||||
+ ".safetensors"
|
||||
)
|
||||
|
||||
for module_name in in_tensors:
|
||||
if module_name not in out_tensors:
|
||||
out_tensors[module_name] = in_tensors[module_name].half()
|
||||
out_shard_paths[module_name] = out_shard_name
|
||||
|
||||
shard_fn = str(Path(model_dst) / out_shard_name)
|
||||
LOG.info(f"saving tensors to {shard_fn}")
|
||||
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
|
||||
|
||||
del in_tensors
|
||||
del out_tensors
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if actually_save and len(unique_shards) > 1:
|
||||
with open(
|
||||
str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
|
||||
) as file:
|
||||
json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)
|
||||
|
||||
|
||||
def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
|
||||
modules = find_lora_modules(model)
|
||||
shard_paths = sharded_paths(checkpoint_path, modules.keys())
|
||||
unique_shards = list(set(shard_paths.values()))
|
||||
|
||||
for shard_path in unique_shards:
|
||||
tensors = st.load_file(os.path.join(checkpoint_path, shard_path))
|
||||
|
||||
for module_name, target in modules.items():
|
||||
key = module_name + ".weight"
|
||||
if key not in shard_paths or shard_paths[key] != shard_path:
|
||||
continue
|
||||
|
||||
new_weight = tensors[key]
|
||||
update_weights(
|
||||
target, new_weight, reinit=False, device=target.weight.device
|
||||
)
|
||||
@@ -1,415 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# This code is based off the following work:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
""" PyTorch StableLM Epoch model. """
|
||||
import importlib
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import init_empty_weights
|
||||
from einops import rearrange
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
|
||||
# this is a wonky hack to get the remotely loaded module
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_stablelm_epoch to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
|
||||
)
|
||||
modeling_stablelm = importlib.import_module(module_name)
|
||||
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
|
||||
flashattn_attn
|
||||
)
|
||||
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
|
||||
stablelm_model_forward
|
||||
)
|
||||
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
|
||||
decoder_layer_forward
|
||||
)
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
# pylint: disable=invalid-name
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
# pylint: disable=invalid-name
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def flashattn_attn(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
query_rot = query_states[..., : self.rotary_ndims]
|
||||
query_pass = query_states[..., self.rotary_ndims :]
|
||||
key_rot = key_states[..., : self.rotary_ndims]
|
||||
key_pass = key_states[..., self.rotary_ndims :]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_rot, key_rot, cos, sin, position_ids
|
||||
)
|
||||
|
||||
# [batch_size, num_heads, seq_len, head_dim]
|
||||
query_states = torch.cat((query_states, query_pass), dim=-1)
|
||||
key_states = torch.cat((key_states, key_pass), dim=-1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Reuse k, v, self_attention
|
||||
key_states = torch.cat((past_key_value[0], key_states), dim=2)
|
||||
value_states = torch.cat((past_key_value[1], value_states), dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# Repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
softmax_scale = None
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
|
||||
)
|
||||
|
||||
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# Upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
# Merge heads
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
# Final linear projection
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
def decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: Optional[torch.FloatTensor],
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
|
||||
]:
|
||||
# pylint: disable=duplicate-code
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def stablelm_model_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# Retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# Embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# Add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
@@ -1,12 +1,9 @@
|
||||
"""Module to load prompt strategies."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||
|
||||
|
||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||
def load(strategy, tokenizer, cfg):
|
||||
try:
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
@@ -14,13 +11,6 @@ def load(strategy, tokenizer, cfg, ds_cfg):
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||
else:
|
||||
sig = inspect.signature(func)
|
||||
if "ds_cfg" in sig.parameters:
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
return func(tokenizer, cfg)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Module for Alpaca prompt strategy classes"""
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompt_style = PromptStyle.CHAT.value
|
||||
if ds_cfg and "conversation" in ds_cfg:
|
||||
prompt_style = ds_cfg["conversation"]
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(prompt_style),
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -57,8 +57,6 @@ class SystemDataPrompter(AlpacaPrompter):
|
||||
Alpaca Style Prompter that uses system prompts from the dataset
|
||||
"""
|
||||
|
||||
system_format: str = "### System:\n{system}\n\n"
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
"""
|
||||
Basic completion text
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
|
||||
|
||||
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Completion prompts.
|
||||
"""
|
||||
|
||||
_field: str = "text"
|
||||
|
||||
def __init__(self, *args, max_length=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if max_length is not None:
|
||||
self.max_length = max_length
|
||||
|
||||
@property
|
||||
def supports_batched(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def field(self) -> str:
|
||||
return self._field
|
||||
|
||||
@field.setter
|
||||
def field(self, new_field: str):
|
||||
self._field = new_field
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
prompt[self.field],
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
res = defaultdict(lambda: [])
|
||||
feature_names = list(prompt.keys())
|
||||
for row in zip(*prompt.values()):
|
||||
prompt_row = dict(zip(feature_names, row))
|
||||
(
|
||||
instruction,
|
||||
_,
|
||||
_,
|
||||
) = self.parse_instruction_fields(prompt_row)
|
||||
|
||||
full_prompt = self._build_full_prompt(instruction, None, None)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
|
||||
for key, val in tokenized_full_prompt.items():
|
||||
for i in range(0, len(val), self.sequence_len):
|
||||
res[key].append(val[i : i + self.sequence_len])
|
||||
|
||||
return dict(res)
|
||||
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response
|
||||
): # pylint: disable=redefined-builtin
|
||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||
|
||||
|
||||
class CompletionPrompter:
|
||||
"""
|
||||
Prompter for completion
|
||||
"""
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input=None, # pylint: disable=redefined-builtin, unused-argument
|
||||
output=None, # pylint: disable=unused-argument
|
||||
) -> Generator[str, None, None]:
|
||||
yield instruction
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
strat = CompletionPromptTokenizingStrategy(
|
||||
CompletionPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
max_length=cfg.sequence_len * 64,
|
||||
)
|
||||
if ds_cfg and "field" in ds_cfg:
|
||||
strat.field = ds_cfg["field"]
|
||||
|
||||
return strat
|
||||
@@ -24,15 +24,6 @@ def load(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
def load_v2(tokenizer, cfg):
|
||||
return ContextQaV2PromptTokenizingStrategy(
|
||||
ContextV2Prompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Customized system prompted for concise QA
|
||||
@@ -59,38 +50,6 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
||||
)
|
||||
|
||||
|
||||
class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question and answer
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
"Context: "
|
||||
+ prompt["context"]
|
||||
+ "\nQuestion: "
|
||||
+ prompt["question"]
|
||||
+ "\n",
|
||||
"",
|
||||
"Answer: " + prompt["answer"],
|
||||
)
|
||||
|
||||
|
||||
class ContextV2Prompter(AlpacaPrompter):
|
||||
"""
|
||||
Customized system prompted for concise QA
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
|
||||
def match_prompt_style(self):
|
||||
# pylint: disable=duplicate-code
|
||||
self.turn_format = "{instruction}\n{input}"
|
||||
self.turn_no_input_format = "{instruction}"
|
||||
self.system_format = "{system}"
|
||||
|
||||
|
||||
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
InstructionPromptTokenizingStrategy
|
||||
):
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
|
||||
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
IGNORE_TOKEN_ID = -100
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for the Metharme models
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (prompt["prompt"], "", prompt["generation"])
|
||||
|
||||
def _tokenize(
|
||||
self,
|
||||
prompt: str,
|
||||
add_eos_token: bool = True,
|
||||
strip_bos_token: bool = False,
|
||||
num_eos_tokens: int = 3,
|
||||
):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if len(result["input_ids"]) == 0:
|
||||
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||
# If there's already an EOS token there, subtract from the number added
|
||||
if result["input_ids"][-1] == self.tokenizer.eos_token_id:
|
||||
num_eos_tokens -= 1
|
||||
|
||||
if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
|
||||
for _ in range(num_eos_tokens):
|
||||
if len(result["input_ids"]) < self.sequence_len:
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
|
||||
class MetharmePrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for the Metharme models.
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
system_format = ""
|
||||
turn_format = "{instruction}"
|
||||
turn_no_input_format = "{instruction}"
|
||||
|
||||
def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
|
||||
pass
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return MetharmePromptTokenizingStrategy(
|
||||
MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
)
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -1,35 +1,12 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message="You are a helpful assistant.",
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>\n",
|
||||
)
|
||||
)
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
),
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -38,7 +15,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -47,7 +24,7 @@ def load_role(tokenizer, cfg):
|
||||
|
||||
def load_guanaco(tokenizer, cfg):
|
||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -1,98 +0,0 @@
|
||||
"""
|
||||
User Defined prompts with configuration from the YML config
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
SystemDataPrompter,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserDefinedDatasetConfig:
|
||||
"""
|
||||
dataclass configuration representing a userdefined dataset type
|
||||
"""
|
||||
|
||||
system_prompt: str = ""
|
||||
field_system: str = "system"
|
||||
field_instruction: str = "instruction"
|
||||
field_input: str = "input"
|
||||
field_output: str = "output"
|
||||
format: str = "{instruction} {input} "
|
||||
no_input_format: str = "{instruction} "
|
||||
system_format: str = "{system}"
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||
"""
|
||||
Prompt Tokenization Strategy for user defined prompts
|
||||
"""
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
|
||||
if not ds_cfg:
|
||||
raise ValueError("Missing dataset prompt configuration")
|
||||
|
||||
system_prompt = ""
|
||||
if ds_cfg.system_prompt:
|
||||
system_prompt = ds_cfg.system_prompt
|
||||
|
||||
def parse_instruction_fields(
|
||||
field_instruction,
|
||||
field_input,
|
||||
field_output,
|
||||
field_system,
|
||||
system_prompt,
|
||||
prompt,
|
||||
) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt[field_instruction],
|
||||
prompt[field_input] if field_input in prompt else "",
|
||||
prompt[field_output] if field_output in prompt else "",
|
||||
prompt[field_system] if field_system in prompt else system_prompt,
|
||||
)
|
||||
|
||||
turn_format = ds_cfg.format
|
||||
turn_no_input_format = ds_cfg.no_input_format
|
||||
system_format = ds_cfg.system_format
|
||||
|
||||
class UserDefinedPrompter(SystemDataPrompter):
|
||||
"""
|
||||
Prompter for user defined prompts
|
||||
"""
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = turn_format
|
||||
self.turn_no_input_format = turn_no_input_format
|
||||
self.system_format = system_format
|
||||
|
||||
prompter = UserDefinedPrompter()
|
||||
|
||||
strat = UserDefinedPromptTokenizationStrategy(
|
||||
prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
setattr(
|
||||
strat,
|
||||
"parse_instruction_fields",
|
||||
partial(
|
||||
parse_instruction_fields,
|
||||
ds_cfg.field_instruction,
|
||||
ds_cfg.field_input,
|
||||
ds_cfg.field_output,
|
||||
ds_cfg.field_system,
|
||||
system_prompt,
|
||||
),
|
||||
)
|
||||
return strat
|
||||
@@ -2,27 +2,22 @@
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||
add_get_turns_to_conversation,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec
|
||||
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
|
||||
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||
|
||||
add_get_turns_to_conversation()
|
||||
|
||||
|
||||
class InvalidDataException(Exception):
|
||||
"""
|
||||
@@ -46,47 +41,51 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
self.train_on_inputs = train_on_inputs
|
||||
self.sequence_len = sequence_len
|
||||
self.max_length = sequence_len
|
||||
|
||||
@abc.abstractmethod
|
||||
def tokenize_prompt(self, prompt):
|
||||
pass
|
||||
|
||||
@property
|
||||
def supports_batched(self):
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_user_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
result: BatchEncoding
|
||||
if not prompt:
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_assistant_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if len(result["input_ids"]) == 0:
|
||||
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.max_length
|
||||
and len(result["input_ids"]) < self.sequence_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
@@ -237,6 +236,23 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Completion prompts.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
full_prompt = self._build_full_prompt(prompt["text"], None, None)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response
|
||||
): # pylint: disable=redefined-builtin
|
||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||
|
||||
|
||||
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Reflection prompts.
|
||||
@@ -335,82 +351,51 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
if (
|
||||
conversation.name == "vicuna_v1.1"
|
||||
and "roles" in prompt
|
||||
and len(prompt["roles"]) >= 2
|
||||
):
|
||||
role_remap = [
|
||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||
]
|
||||
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
if conversation.roles[0] in part[0]:
|
||||
role = (
|
||||
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else part[0]
|
||||
)
|
||||
turn = role + part[1]
|
||||
if part[0] == "USER:":
|
||||
part = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
part.strip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif conversation.roles[1] in part[0]:
|
||||
elif part[0] == "ASSISTANT:":
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
role = (
|
||||
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else part[0]
|
||||
)
|
||||
turn = role + part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
part = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistent response, should end with an eos token
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
part.strip(),
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
elif part[0] == "":
|
||||
turn = part[1]
|
||||
elif part[0] == "SYSTEM:":
|
||||
part = part[1] # Ignore the system role from preamble
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {part[0]}")
|
||||
continue
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
@@ -425,31 +410,22 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
if not prompt.strip():
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.sequence_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Module containing prompters"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
@@ -27,7 +26,7 @@ class AlpacaPrompter:
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
system_format: str = "{system}"
|
||||
system_format: str
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
prompt_style: Optional[PromptStyle] = None
|
||||
@@ -64,17 +63,13 @@ class AlpacaPrompter:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input)
|
||||
res = self.system_prompt + self.turn_format.format(
|
||||
instruction=instruction, input=input
|
||||
)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
res = self.system_no_input_prompt + self.turn_no_input_format.format(
|
||||
instruction=instruction
|
||||
)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
@@ -136,6 +131,20 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
|
||||
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||
|
||||
|
||||
class CompletionPrompter:
|
||||
"""
|
||||
Prompter for completion
|
||||
"""
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input=None, # pylint: disable=redefined-builtin, unused-argument
|
||||
output=None, # pylint: disable=unused-argument
|
||||
) -> Generator[str, None, None]:
|
||||
yield instruction
|
||||
|
||||
|
||||
class GPTeacherPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Prompter for GPTeacher
|
||||
@@ -215,6 +224,53 @@ class ReflectAlpacaPrompter:
|
||||
yield res
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
DOLLY = auto()
|
||||
|
||||
|
||||
# TODO clean this 💩 up
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: Optional[str] = None
|
||||
|
||||
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
||||
# seps = [self.sep, self.sep2]
|
||||
preamble = self.system + self.sep
|
||||
yield ("SYSTEM:", preamble)
|
||||
for _, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield (role + ":", " " + message)
|
||||
else:
|
||||
LOG.warning(f"role with empty message: {role}")
|
||||
yield (role + ":", "")
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
)
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
@@ -225,29 +281,34 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_style=None, # pylint: disable=unused-argument
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
self._conversation = conversation
|
||||
else:
|
||||
self._conversation = get_conv_template(conversation)
|
||||
else:
|
||||
self._conversation = get_conv_template("vicuna_v1.1")
|
||||
if role_key_human:
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
||||
if prompt_style != PromptStyle.CHAT.value:
|
||||
raise ValueError(
|
||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
||||
)
|
||||
system: str = (
|
||||
system_prompt
|
||||
if system_prompt
|
||||
else (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
)
|
||||
)
|
||||
self._conversation = Conversation(
|
||||
system=system,
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
# ignore the system prompt if provided
|
||||
if source[0]["from"] == "system":
|
||||
source.pop(0)
|
||||
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
@@ -256,17 +317,14 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
)
|
||||
|
||||
conv = self._conversation.copy()
|
||||
|
||||
# Add the conversation system prompt if provided, otherwise use the default one
|
||||
if source[0]["from"] == "system":
|
||||
conv.set_system_message(source[0]["value"])
|
||||
source.pop(0)
|
||||
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
if source[0]["from"] not in roles:
|
||||
if (
|
||||
source[0]["from"] not in roles
|
||||
or roles[source[0]["from"]] != conv.roles[0]
|
||||
):
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as err:
|
||||
@@ -274,33 +332,10 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
if len(conv.messages) > 0 and (
|
||||
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
||||
):
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_turns():
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
for part in conv.get_prompt():
|
||||
yield part
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
A V2 prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
conversation=conversation,
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
)
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.train")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainDatasetMeta:
|
||||
"""
|
||||
dataclass to capture the dataset specific options for training
|
||||
"""
|
||||
|
||||
train_dataset: Dataset
|
||||
eval_dataset: Optional[Dataset] = None
|
||||
total_num_steps: Optional[int] = None
|
||||
|
||||
|
||||
def train(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
dataset_meta: TrainDatasetMeta,
|
||||
):
|
||||
# load the tokenizer first
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# Load the model and tokenizer
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
peft_config.save_pretrained(cfg.output_dir)
|
||||
# additionally presave the tokenizer and model configs
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
||||
model = model.merge_and_unload()
|
||||
else:
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return model, tokenizer
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
trainer.save_model(cfg.output_dir)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
|
||||
return model, tokenizer
|
||||
@@ -1,44 +1,13 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
import functools
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
from pynvml.nvml import NVMLError
|
||||
|
||||
|
||||
def check_cuda_device(default_value):
|
||||
"""
|
||||
wraps a function and returns the default value instead of running the
|
||||
wrapped function if cuda isn't available or the device is auto
|
||||
:param default_value:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def deco(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
device = kwargs.get("device", args[0] if args else None)
|
||||
|
||||
if (
|
||||
not torch.cuda.is_available()
|
||||
or device == "auto"
|
||||
or torch.device(device).type == "cpu"
|
||||
):
|
||||
return default_value
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage(device=0):
|
||||
return torch.cuda.memory_allocated(device) / 1024.0**3
|
||||
|
||||
|
||||
@check_cuda_device((0.0, 0.0, 0.0))
|
||||
def gpu_memory_usage_all(device=0):
|
||||
usage = torch.cuda.memory_allocated(device) / 1024.0**3
|
||||
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
|
||||
@@ -46,22 +15,22 @@ def gpu_memory_usage_all(device=0):
|
||||
return usage, reserved - usage, max(0, smi - reserved)
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage_smi(device=0):
|
||||
if isinstance(device, torch.device):
|
||||
device = device.index
|
||||
if isinstance(device, str) and device.startswith("cuda:"):
|
||||
device = int(device[5:])
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used / 1024.0**3
|
||||
except NVMLError:
|
||||
return 0.0
|
||||
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used / 1024.0**3
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
if not torch.cuda.is_available():
|
||||
return (0, 0, 0)
|
||||
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
if cache > 0:
|
||||
|
||||
@@ -1,23 +1,10 @@
|
||||
"""Callbacks for Trainer class"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
GenerationConfig,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
@@ -26,43 +13,28 @@ from transformers import (
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
broadcast_dict,
|
||||
gather_scalar_from_all_ranks,
|
||||
get_world_size,
|
||||
is_distributed,
|
||||
is_main_process,
|
||||
zero_first,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
class EvalFirstStepCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||
"""
|
||||
Callback to trigger evals on the first step
|
||||
"""
|
||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the PEFT adapter"""
|
||||
|
||||
def on_step_end(
|
||||
def on_save(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and args.eval_steps < 1.0
|
||||
and state.global_step == 1
|
||||
):
|
||||
control.should_evaluate = True
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||
kwargs["model"].save_pretrained(peft_model_path)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
@@ -122,419 +94,3 @@ class GPUStatsCallback(
|
||||
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
|
||||
self.logged = True
|
||||
return control
|
||||
|
||||
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
tokenizer("A", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("B", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("C", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("D", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("E", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("F", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("G", add_special_tokens=False).input_ids[0],
|
||||
]
|
||||
bench_split = "eval"
|
||||
|
||||
def transform_bench_subject(example):
|
||||
# Split on ':' and trim whitespace
|
||||
parts = example["subject"].split(":")
|
||||
first_part = (
|
||||
parts[0].strip().lower().replace("-", "_")
|
||||
) # Lowercase the first part
|
||||
second_part = (
|
||||
parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
|
||||
) # Replace hyphens with underscores
|
||||
|
||||
# Return the transformed values
|
||||
return {"name": first_part, "subject": second_part}
|
||||
|
||||
if trainer.args.bench_dataset == "mmlu-zs":
|
||||
bench_dataset = load_dataset(
|
||||
"openaccess-ai-collective/mmlu-evals",
|
||||
data_files={
|
||||
"eval": "zero_shot_mmlu_val.json",
|
||||
"test": "zero_shot_mmlu_test.json",
|
||||
},
|
||||
)
|
||||
# bench_dataset = bench_dataset.remove_columns("subject")
|
||||
# MMLU Five-shot (Eval/Test only)
|
||||
elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
|
||||
bench_dataset = load_dataset(
|
||||
"openaccess-ai-collective/mmlu-evals",
|
||||
data_files={
|
||||
"eval": "five_shot_mmlu_val.json",
|
||||
"test": "five_shot_mmlu_test.json",
|
||||
},
|
||||
)
|
||||
# bench_dataset = bench_dataset.remove_columns('subject')
|
||||
elif "/" in trainer.args.bench_dataset:
|
||||
bench_ds = trainer.args.bench_dataset
|
||||
bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
|
||||
bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
|
||||
bench_dataset = load_dataset(
|
||||
bench_ds_name,
|
||||
data_files={
|
||||
"eval": bench_ds_data_file,
|
||||
},
|
||||
)
|
||||
bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
|
||||
)
|
||||
bench_dataset = bench_dataset[trainer.args.bench_split]
|
||||
if trainer.args.max_bench_samples is not None:
|
||||
bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
|
||||
|
||||
def tokenize_evals(example):
|
||||
source = f"{tokenizer.bos_token}{example['input']}"
|
||||
target = f"{example['output']}{tokenizer.eos_token}"
|
||||
|
||||
tokenized_source = tokenizer(
|
||||
source,
|
||||
max_length=2048,
|
||||
truncation=True,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
tokenized_target = tokenizer(
|
||||
target,
|
||||
max_length=2048,
|
||||
truncation=True,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
|
||||
labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
|
||||
"input_ids"
|
||||
]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"subject": example["subject"],
|
||||
}
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
bench_dataset = bench_dataset.map(tokenize_evals)
|
||||
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||
|
||||
class BenchEvalCallback(TrainerCallback):
|
||||
"""
|
||||
TrainerCallback that runs the MMLU evals
|
||||
"""
|
||||
|
||||
def on_evaluate(
|
||||
self,
|
||||
args: AxolotlTrainingArguments,
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl, # pylint: disable=unused-argument
|
||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
data_loader = trainer.get_bench_dataloader(
|
||||
bench_dataset.remove_columns(["input", "subject", "output", "name"])
|
||||
)
|
||||
trainer.model.eval()
|
||||
preds, refs = [], []
|
||||
loss_bench = 0
|
||||
for batch in tqdm(data_loader, total=len(data_loader)):
|
||||
(loss, logits, labels) = trainer.prediction_step(
|
||||
trainer.model,
|
||||
batch,
|
||||
prediction_loss_only=False,
|
||||
)
|
||||
# There are two tokens, the output, and eos token.
|
||||
for i, logit in enumerate(logits):
|
||||
label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
|
||||
0
|
||||
][0]
|
||||
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
||||
preds.append(torch.argmax(logit_abcd).item())
|
||||
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||
refs += [
|
||||
abcd_idx.index(label) if label in abcd_idx else -1
|
||||
for label in labels.tolist()
|
||||
]
|
||||
loss_bench += loss.item()
|
||||
# Extract results by subject.
|
||||
bench_name = bench_dataset["name"]
|
||||
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
|
||||
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
||||
bench_names[s]["preds"].append(p)
|
||||
bench_names[s]["refs"].append(r)
|
||||
barrier()
|
||||
local_bench_names = bench_names
|
||||
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
|
||||
# Gather results from all GPUs to GPU 0
|
||||
|
||||
loss_bench_ranks = gather_scalar_from_all_ranks(
|
||||
lambda: loss_bench, get_world_size()
|
||||
)
|
||||
len_data_loader_ranks = gather_scalar_from_all_ranks(
|
||||
lambda: len(data_loader), get_world_size()
|
||||
)
|
||||
|
||||
results = {}
|
||||
if is_distributed() and not is_main_process():
|
||||
dist.gather_object(local_bench_names, dst=0)
|
||||
else:
|
||||
if is_distributed():
|
||||
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
||||
else:
|
||||
gathered_bench_names = [local_bench_names]
|
||||
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
|
||||
results = {f"{bench_split}_bench_loss": bench_loss}
|
||||
|
||||
# Combine results from all GPUs
|
||||
combined_bench_names: Dict[str, Dict[str, List]] = {}
|
||||
for bench_name in gathered_bench_names:
|
||||
for name, data in bench_name.items():
|
||||
if name not in combined_bench_names:
|
||||
combined_bench_names[name] = {"refs": [], "preds": []}
|
||||
combined_bench_names[name]["refs"].extend(data["refs"])
|
||||
combined_bench_names[name]["preds"].extend(data["preds"])
|
||||
|
||||
bench_scores = []
|
||||
bench_refs = []
|
||||
bench_preds = []
|
||||
for (
|
||||
bench_name
|
||||
) in combined_bench_names: # pylint: disable=consider-using-dict-items
|
||||
bench_score = accuracy.compute(
|
||||
references=combined_bench_names[bench_name]["refs"],
|
||||
predictions=combined_bench_names[bench_name]["preds"],
|
||||
)["accuracy"]
|
||||
bench_refs.extend(combined_bench_names[bench_name]["refs"])
|
||||
bench_preds.extend(combined_bench_names[bench_name]["preds"])
|
||||
if not pd.isna(bench_score):
|
||||
results[
|
||||
f"{bench_split}_bench_accuracy_{bench_name}"
|
||||
] = bench_score
|
||||
bench_scores.append(bench_score)
|
||||
else:
|
||||
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
|
||||
bench_scores.append(0.0)
|
||||
results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
|
||||
results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
|
||||
references=bench_refs, predictions=bench_preds
|
||||
)["accuracy"]
|
||||
trainer.log(results)
|
||||
|
||||
results = broadcast_dict(results)
|
||||
for key, val in results.items():
|
||||
metrics[key] = val
|
||||
|
||||
return BenchEvalCallback
|
||||
|
||||
|
||||
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||
class LogPredictionCallback(TrainerCallback):
|
||||
"""Callback to log prediction values during each evaluation"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
|
||||
def on_evaluate(
|
||||
self,
|
||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
train_dataloader, # pylint: disable=unused-argument
|
||||
eval_dataloader,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
eval_table_size = self.cfg.eval_table_size
|
||||
|
||||
if eval_table_size <= 0:
|
||||
return control
|
||||
|
||||
trainer.model.eval()
|
||||
device = torch.device(self.cfg.device)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=self.cfg.eval_table_max_new_tokens,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=False,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
|
||||
def logits_to_tokens(logits) -> torch.Tensor:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
# Get the predicted token ids (the ones with the highest probability)
|
||||
predicted_token_ids = torch.argmax(probabilities, dim=-1)
|
||||
return predicted_token_ids
|
||||
|
||||
def find_ranges(lst):
|
||||
ranges = []
|
||||
start = 0
|
||||
for i in range(1, len(lst)):
|
||||
if lst[i] == 0:
|
||||
ranges.append((start, i - 1))
|
||||
start = i
|
||||
end = len(lst) - 1
|
||||
ranges.append((start, end))
|
||||
return ranges
|
||||
|
||||
def log_table_from_dataloader(name: str, table_dataloader):
|
||||
table = wandb.Table( # type: ignore[attr-defined]
|
||||
columns=[
|
||||
"id",
|
||||
"Prompt",
|
||||
"Correct Completion",
|
||||
"Predicted Completion (model.generate)",
|
||||
"Predicted Completion (trainer.prediction_step)",
|
||||
]
|
||||
)
|
||||
row_index = 0
|
||||
|
||||
for batch in tqdm(table_dataloader):
|
||||
if row_index > eval_table_size:
|
||||
break
|
||||
|
||||
batch_labels = batch["labels"].to(device)
|
||||
batch_input_ids = batch["input_ids"].to(device)
|
||||
|
||||
if "position_ids" in batch:
|
||||
batch_pos_ids = batch["position_ids"].tolist()
|
||||
else:
|
||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||
|
||||
(_, batch_logits, _) = trainer.prediction_step(
|
||||
trainer.model,
|
||||
batch,
|
||||
prediction_loss_only=False,
|
||||
)
|
||||
|
||||
prompt_token_ids_list = []
|
||||
pred_step_token_ids_list = []
|
||||
completion_token_ids_list = []
|
||||
|
||||
for input_ids_all, labels_all, pos_ids, logits in zip(
|
||||
batch_input_ids,
|
||||
batch_labels,
|
||||
batch_pos_ids,
|
||||
batch_logits,
|
||||
):
|
||||
if pos_ids is None:
|
||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||
else:
|
||||
pos_ranges = find_ranges(pos_ids)
|
||||
|
||||
for pos_range in pos_ranges:
|
||||
start, end = pos_range
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
input_ids = input_ids_all[start : end + 1]
|
||||
labels = labels_all[start : end + 1]
|
||||
|
||||
tokens_without_loss = labels == IGNORE_INDEX
|
||||
tokens_with_loss = labels != IGNORE_INDEX
|
||||
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
|
||||
prompt_token_includes = (
|
||||
tokens_without_loss & tokens_exclude_padding
|
||||
)
|
||||
|
||||
prompt_token_ids = input_ids[prompt_token_includes]
|
||||
prompt_token_ids_list.append(prompt_token_ids)
|
||||
|
||||
completion_token_ids = input_ids[tokens_with_loss]
|
||||
completion_token_ids_list.append(completion_token_ids)
|
||||
|
||||
pred_step_token_ids = logits_to_tokens(
|
||||
logits[start : end + 1]
|
||||
)[tokens_with_loss]
|
||||
pred_step_token_ids_list.append(pred_step_token_ids)
|
||||
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
completion_texts = tokenizer.batch_decode(
|
||||
completion_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
pred_step_texts = tokenizer.batch_decode(
|
||||
pred_step_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_encoding = tokenizer(
|
||||
prompt_texts, padding=True, return_tensors="pt"
|
||||
).to(self.cfg.device)
|
||||
predictions = trainer.model.generate(
|
||||
**prompt_encoding, generation_config=generation_config
|
||||
)
|
||||
|
||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||
prediction_without_prompt_tokens_list = []
|
||||
for prompt_token_ids, prediction_tokens in zip(
|
||||
prompt_token_ids_list, prediction_all_tokens
|
||||
):
|
||||
prediction_without_prompt_tokens = prediction_tokens[
|
||||
len(prompt_token_ids) :
|
||||
]
|
||||
prediction_without_prompt_tokens_list.append(
|
||||
prediction_without_prompt_tokens
|
||||
)
|
||||
|
||||
predicted_texts = tokenizer.batch_decode(
|
||||
prediction_without_prompt_tokens_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
for (
|
||||
prompt_text,
|
||||
completion_text,
|
||||
prediction_text,
|
||||
pred_step_text,
|
||||
) in zip(
|
||||
prompt_texts, completion_texts, predicted_texts, pred_step_texts
|
||||
):
|
||||
table.add_data(
|
||||
row_index,
|
||||
prompt_text,
|
||||
completion_text,
|
||||
prediction_text,
|
||||
pred_step_text,
|
||||
)
|
||||
row_index += 1
|
||||
|
||||
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
|
||||
|
||||
if is_main_process():
|
||||
log_table_from_dataloader("Eval", eval_dataloader)
|
||||
|
||||
return control
|
||||
|
||||
return LogPredictionCallback
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to wandb"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
||||
artifact.add_file(local_path=self.axolotl_config_path)
|
||||
wandb.run.log_artifact(artifact)
|
||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user