Compare commits

..

1 Commit

Author: Wing Lian
SHA1: 881d333b84
Message: wip for new datasets abstractions
Date: 2023-09-05 16:37:48 -04:00
Checks: some checks failed (pre-commit, PyTest 3.10, and PyTest 3.9 push checks were cancelled)
122 changed files with 2307 additions and 8705 deletions


@@ -53,13 +53,6 @@ body:
 validations:
 required: true
-- type: textarea
-id: config
-attributes:
-label: Config yaml
-description: |
-Please attach the config yaml!
 - type: textarea
 id: possible-solution
 attributes:


@@ -25,11 +25,6 @@ jobs:
 python_version: "3.10"
 pytorch: 2.0.1
 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
-- cuda: "118"
-cuda_version: 11.8.0
-python_version: "3.10"
-pytorch: 2.1.0
-torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
 steps:
 - name: Checkout
 uses: actions/checkout@v3


@@ -23,13 +23,7 @@ jobs:
 python_version: "3.10"
 pytorch: 2.0.1
 axolotl_extras:
-is_latest: true
-- cuda: 118
-cuda_version: 11.8.0
-python_version: "3.10"
-pytorch: 2.1.0
-axolotl_extras:
-runs-on: [self-hosted, gpu, docker]
+runs-on: self-hosted
 steps:
 - name: Checkout
 uses: actions/checkout@v3
@@ -52,12 +46,9 @@ jobs:
 build-args: |
 BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
 CUDA=${{ matrix.cuda }}
-PYTORCH_VERSION=${{ matrix.pytorch }}
 file: ./docker/Dockerfile
 push: ${{ github.event_name != 'pull_request' }}
-tags: |
-${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
-${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
+tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
 labels: ${{ steps.metadata.outputs.labels }}
 build-axolotl-runpod:
 needs: build-axolotl
@@ -77,12 +68,7 @@ jobs:
 pytorch: 2.0.1
 axolotl_extras:
 is_latest: true
-- cuda: 118
-cuda_version: 11.8.0
-python_version: "3.10"
-pytorch: 2.1.0
-axolotl_extras:
-runs-on: [self-hosted, gpu, docker]
+runs-on: self-hosted
 steps:
 - name: Checkout
 uses: actions/checkout@v3

.github/workflows/pre-commit.yml (new file, 16 lines added)

@@ -0,0 +1,16 @@
name: pre-commit
on:
pull_request:
push:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
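
The same hooks can be run locally before pushing, so failures surface without waiting for CI; a minimal sketch, assuming the repository's own `.pre-commit-config.yaml` that this workflow implies:

```bash
pip3 install pre-commit
pre-commit install          # optional: run the hooks automatically on every git commit
pre-commit run --all-files  # one-off run over the whole tree, same as the workflow
```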


@@ -1,45 +0,0 @@
name: publish pypi
on:
push:
tags:
- '*'
jobs:
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip3 install wheel
pip3 install -e .
pip3 install -r requirements-tests.txt
- name: Extract tag name
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in setup.py
run: >-
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a binary wheel
run: >-
python setup.py sdist bdist_wheel
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1


@@ -1,32 +1,10 @@
-name: Tests
+name: PyTest
 on:
-# check on push/merge to main, PRs, and manual triggers
 push:
-branches:
-- "main"
-paths:
-- '**.py'
-- 'requirements.txt'
 pull_request:
-paths:
-- '**.py'
-- 'requirements.txt'
-workflow_dispatch:
 jobs:
-pre-commit:
-name: pre-commit
-runs-on: ubuntu-latest
-steps:
-- uses: actions/checkout@v3
-- uses: actions/setup-python@v4
-with:
-python-version: "3.9"
-cache: 'pip' # caching pip dependencies
-- uses: pre-commit/action@v3.0.0
-pytest:
-name: PyTest
+test:
 runs-on: ubuntu-latest
 strategy:
 fail-fast: false
@@ -46,35 +24,9 @@ jobs:
 - name: Install dependencies
 run: |
-pip3 install -U -e .
-pip3 install -r requirements-tests.txt
+pip install -e .
+pip install -r requirements-tests.txt
 - name: Run tests
 run: |
-pytest --ignore=tests/e2e/ tests/
+pytest tests/
-e2e-test:
-name: E2E Tests
-runs-on: [self-hosted, gpu]
-timeout-minutes: 20
-needs: [pre-commit, pytest]
-steps:
-- name: Check out repository code
-uses: actions/checkout@v3
-- name: Setup Python
-uses: actions/setup-python@v4
-with:
-python-version: "3.10"
-# cache: 'pip' # caching pip dependencies
-- name: Install dependencies
-run: |
-pip3 uninstall -y transformers accelerate
-pip3 install -U -e .[flash-attn]
-pip3 install -r requirements-tests.txt
-- name: Run e2e tests
-run: |
-pytest tests/e2e/

.gitignore (4 lines changed)

@@ -161,7 +161,3 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
-# WandB
-# wandb creates a folder to store logs for training runs
-wandb


@@ -1,3 +1,2 @@
 [settings]
 profile=black
-known_third_party=wandb


@@ -8,9 +8,6 @@ ignore_missing_imports = True
 [mypy-axolotl.monkeypatch.*]
 ignore_errors = True
-[mypy-axolotl.models.phi.*]
-ignore_errors = True
 [mypy-flash_attn.*]
 ignore_missing_imports = True
@@ -23,9 +20,6 @@ ignore_missing_imports = True
 [mypy-peft]
 ignore_missing_imports = True
-[mypy-wandb]
-ignore_missing_imports = True
 [mypy-bitsandbytes]
 ignore_missing_imports = True

README.md (562 lines changed)

@@ -2,18 +2,6 @@
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with xformer, flash attention, rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb
- And more!
<table> <table>
<tr> <tr>
<td> <td>
@@ -23,12 +11,9 @@ Features:
 - [Supported Features](#axolotl-supports)
 - [Quickstart](#quickstart-)
 - [Installation](#installation)
-  - [Docker](#docker)
-  - [Conda/Pip venv](#condapip-venv)
-  - [Runpod](#runpod)
-  - [LambdaLabs](#lambdalabs)
-  - [Windows](#windows)
-  - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
+  - [Docker Installation](#environment)
+  - [Conda/Pip venv Installation](#condapip-venv)
+  - [LambdaLabs Installation](#lambdalabs)
 - [Dataset](#dataset)
 - [How to Add Custom Prompts](#how-to-add-custom-prompts)
 - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
@@ -52,7 +37,7 @@ Features:
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b> <b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p> </p>
<p> <p>
Go ahead and Axolotl questions!! Go ahead and axolotl questions!!
</p> </p>
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit"> <img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main"> <img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
@@ -66,17 +51,14 @@ Features:
 ## Axolotl supports
 |          | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
-|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
+|----------|:----------|:-----|-------|------|-------------------|------------|---------------|
 | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
-| btlm | ✅ |  |  | ❌ | ❌ | ❌ | ❓ |
-| mpt | ✅ |  |  | ❌ | ❌ | ❌ | ❓ |
-| falcon | ✅ | ✅ | ✅ | ❌ | ❌ |  | ❓ |
-| gpt-j | ✅ |  | ✅ |  |  | ❓ |  |
-| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
-| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
-| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
+| mpt | ✅ |  |  | ❌ | ❌ | ❌ | ❓ |
+| falcon | ✅ |  |  | ❌ | ❌ | ❌ | ❓ |
+| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ |  | ❓ |
+| XGen | ✅ |  | ✅ |  |  | ❓ |  |
 ## Quickstart ⚡
@@ -89,30 +71,27 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
 git clone https://github.com/OpenAccess-AI-Collective/axolotl
 cd axolotl
-pip3 install packaging
-pip3 install -e '.[flash-attn,deepspeed]'
+pip3 install -e .[flash-attn]
 pip3 install -U git+https://github.com/huggingface/peft.git
 # finetune lora
-accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
 # inference
-accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out"
-# gradio
-accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out" --gradio
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
+    --inference --lora_model_dir="./lora-out"
 ```
## Installation ## Installation
### Environment ### Environment
#### Docker - Docker
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
Or run on the current files for development: Or run on the current files for development:
@@ -120,48 +99,27 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
docker compose up -d docker compose up -d
``` ```
<details> - Conda/Pip venv
<summary>Docker advanced</summary>
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
```
It additionally:
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
* The `--privileged` flag gives all capabilities to the container.
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
</details>
#### Conda/Pip venv
1. Install python >=**3.9** 1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/ 2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install Axolotl along with python dependencies 3. Install python dependencies with ONE of the following:
- Recommended, supports QLoRA, NO gptq/int4 support
```bash ```bash
pip3 install packaging pip3 install -e .
pip3 install -e '.[flash-attn,deepspeed]' pip3 install -U git+https://github.com/huggingface/peft.git
``` ```
4. (Optional) Login to Huggingface to use gated models/datasets. - gptq/int4 support, NO QLoRA
```bash ```bash
huggingface-cli login pip3 install -e .[gptq]
```
- same as above but not recommended
```bash
pip3 install -e .[gptq_triton]
``` ```
Get the token at huggingface.co/settings/tokens
#### Runpod - LambdaLabs
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### LambdaLabs
<details> <details>
<summary>Click to Expand</summary> <summary>Click to Expand</summary>
@@ -193,10 +151,10 @@ Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runp
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl cd axolotl
pip3 install packaging pip3 install -e . # change depend on needs
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install protobuf==3.20.3 pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy pip3 install -U --ignore-installed requests Pillow psutil scipy
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
``` ```
5. Set path 5. Set path
@@ -205,30 +163,7 @@ Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runp
``` ```
</details> </details>
#### Windows - Windows: Please use WSL or Docker!
Please use WSL or Docker!
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
# Managed spot (auto-recovery on preemption)
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
### Dataset ### Dataset
@@ -239,7 +174,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
 ```json
 {"instruction": "...", "input": "...", "output": "..."}
 ```
-- `sharegpt`: conversations where `from` is `human`/`gpt`
+- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
 ```json
 {"conversations": [{"from": "...", "value": "..."}]}
 ```
@@ -304,10 +239,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"article": "...", "question": "...", "answer": "..."} {"article": "...", "question": "...", "answer": "..."}
``` ```
- `context_qa.load_v2`: in context question answering (alternate)
```json
{"context": "...", "question": "...", "answer": "..."}
```
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context - `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
```json ```json
{"article": "...", "unanswerable_question": "..."} {"article": "...", "unanswerable_question": "..."}
@@ -332,11 +263,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
 ```json
 {"prompt": "...", "generation": "..."}
 ```
-- `sharegpt.load_role`: conversations where `role` is used instead of `from`
+- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
 ```json
 {"conversations": [{"role": "...", "value": "..."}]}
 ```
-- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
+- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
 ```json
 {"conversations": [{"from": "...", "value": "..."}]}
 ```
@@ -349,28 +280,29 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts #### How to add custom prompts
For a dataset that is preprocessed for instruction purposes: Using yaml. Example:
```json
{"instruction": "...", "output": "..."}
```
You can use this example in your YAML config:
```yaml ```yaml
datasets: datasets:
- path: repo - path: repo
type: type:
system_prompt: "" system_prompt: ""
field_system: system no_input_format: |-
format: "[INST] {instruction} [/INST]" User: {instruction}<|end_of_turn|>
no_input_format: "[INST] {instruction} [/INST]" Assistant:
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
``` ```
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
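
As a purely illustrative sketch of such a file (the import paths, class names, and constructor arguments below are assumptions modeled on the existing strategies; mirror a real file from `prompt_strategies` rather than this one):

```python
# hypothetical src/axolotl/prompt_strategies/my_format.py
# referenced from a dataset config as:  type: my_format.load_custom
# NOTE: names and signatures here are assumptions; copy an existing strategy file.
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


def load_custom(tokenizer, cfg):
    # return a tokenizing strategy that turns raw rows into input_ids/labels
    return AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(PromptStyle.CHAT.value),  # swap in a prompter for your format
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
```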
 #### How to use your custom pretokenized dataset
 - Do not pass a `type:`
-- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
+- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
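
For instance, a small sketch of producing such a dataset with the Hugging Face `datasets` library (the tokenizer, texts, and output location are placeholders, not part of axolotl itself):

```python
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")  # match your base_model

rows = []
for text in ["first training example ...", "second training example ..."]:  # your corpus
    enc = tokenizer(text, truncation=True, max_length=2048)
    rows.append({
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels": list(enc["input_ids"]),  # causal LM: labels mirror input_ids
    })

ds = Dataset.from_list(rows)
ds.push_to_hub("your-user/my-pretokenized")  # or save locally and point `datasets: - path:` at it
```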
### Config ### Config
@@ -397,7 +329,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: EleutherAI/pile - path: EleutherAI/pile
name: enron_emails name: enron_emails
type: completion # format from earlier type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets # huggingface repo with multiple named configurations/subsets
datasets: datasets:
@@ -408,30 +339,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript - typescript
type: ... # unimplemented custom format type: ... # unimplemented custom format
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
datasets:
- path: ...
type: sharegpt
conversation: chatml
# local # local
datasets: datasets:
- path: data.jsonl # or json - path: data.jsonl # or json
ds_type: json # see other options below ds_type: json # see other options below
type: alpaca type: alpaca
# dataset with splits, but no train split
dataset:
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
``` ```
- loading - loading
@@ -459,18 +371,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
<details> <details>
<summary>All yaml options (click me)</summary> <summary>All yaml options</summary>
```yaml ```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files # this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# This can also be a relative path to a model on disk # this can also be a relative path to a model on disk
base_model: ./llama-7b-hf base_model: ./llama-7b-hf
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc) # you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns: base_model_ignore_patterns:
# If the base_model repo on hf hub doesn't include configuration .json files, # if the base_model repo on hf hub doesn't include configuration .json files,
# You can set that here, or leave this empty to default to base_model # you can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub # you can specify to choose a specific model revision from huggingface hub
model_revision: model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer # Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
@@ -485,32 +397,18 @@ trust_remote_code:
tokenizer_use_fast: tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True # Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy: tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32 # resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models # this is reported to improve training speed on some models
resize_token_embeddings_to_32x: resize_token_embeddings_to_32x:
# Used to identify which the model is based on # whether you are training a 4-bit GPTQ quantized model
is_falcon_derived_model:
is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
# optional overrides to the base model configuration
model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# Whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
gptq_groupsize: 128 # group size gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2 gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer # this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true load_in_8bit: true
# Use bitsandbytes 4 bit # use bitsandbytes 4 bit
load_in_4bit: load_in_4bit:
# Use CUDA bf16 # Use CUDA bf16
@@ -524,33 +422,28 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere bfloat16: true # require >=ampere
float16: true float16: true
# A list of one or more datasets to finetune the model with # a list of one or more datasets to finetune the model with
datasets: datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files # hf dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn> type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
data_files: # Optional[str] path to source data files data_files: # path to source data files
shards: # Optional[int] number of shards to split data into shards: # number of shards to split data into
name: # Optional[str] name of dataset configuration to load name: # name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
# Optional[str] fastchat conversation type, only used with type: sharegpt # custom user prompt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# Custom user prompt
- path: repo - path: repo
type: type:
# The below are defaults. only set what's needed. # the below are defaults. only set what's needed.
system_prompt: "" system_prompt: ""
system_format: "{system}"
field_system: system field_system: system
field_instruction: instruction field_instruction: instruction
field_input: input field_output: input
field_output: output
# Customizable to be single line or multi-line # customizable to be single line or multi-line
system_format: "{system}"
# 'format' can include {input} # 'format' can include {input}
format: |- format: |-
User: {instruction} {input} User: {instruction} {input}
@@ -558,24 +451,18 @@ datasets:
# 'no_input_format' cannot include {input} # 'no_input_format' cannot include {input}
no_input_format: "{instruction} " no_input_format: "{instruction} "
# For `completion` datsets only, uses the provided field instead of `text` column # axolotl attempts to save the dataset as an arrow after packing the data together so
field:
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub # push prepared dataset to hub
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
# push checkpoints to hub # push checkpoints to hub
hub_model_id: # repo path to push finetuned model hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub # how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy: hub_strategy:
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# Required to be true when used in combination with `push_dataset_to_hub` # required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval. # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04 val_set_size: 0.04
@@ -584,34 +471,28 @@ dataset_shard_num:
# Index of shard to use for whole dataset # Index of shard to use for whole dataset
dataset_shard_idx: dataset_shard_idx:
# The maximum length of an input to train with, this should typically be less than 2048 # the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# Pad inputs so each step uses constant sized buffers # pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently # this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len: pad_to_sequence_len:
# Max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' # use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing: sample_packing:
# Set to 'false' if getting errors during eval with sample_packing on. # you can set these packing optimizations AFTER starting a training at least once.
eval_sample_packing:
# You can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values. # The trainer will provide recommended values for these values.
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # if you already have a lora model trained that you want to load, put that here
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`. # lora hyperparameters
lora_model_dir: lora_model_dir:
# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
@@ -623,122 +504,76 @@ lora_target_modules:
# - gate_proj # - gate_proj
# - down_proj # - down_proj
# - up_proj # - up_proj
lora_target_linear: # If true, will target all linear layers lora_target_linear: # if true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save: lora_modules_to_save:
# - embed_tokens # - embed_tokens
# - lm_head # - lm_head
# Once you complete training, the model will be saved to the following directory.
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir: lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# ReLoRA configuration # ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart relora_steps: # number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps relora_warmup_steps: # number of per-restart warmup steps
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # Your wandb project name wandb_project: # your wandb project name
wandb_entity: # A wandb Team name if using a Team wandb_entity: # a wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # Set the name of your wandb run wandb_run_id: # set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to # where to save the finished model to
output_dir: ./completed-model output_dir: ./completed-model
# Whether to use torch.compile and which backend to use # training hyperparameters
torch_compile: # bool
torch_compile_backend: # Optional[str]
# Training hyperparameters
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
micro_batch_size: 2 micro_batch_size: 2
eval_batch_size: eval_batch_size: 2
num_epochs: 4 num_epochs: 3
warmup_steps: 100 warmup_steps: 100
learning_rate: 0.00003 learning_rate: 0.00003
lr_quadratic_warmup: lr_quadratic_warmup:
logging_steps: logging_steps:
save_strategy: # Set to `no` to skip checkpoint saves save_strategy: # set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch save_steps: # leave empty to save at each epoch
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps eval_steps: # leave empty to eval at each epoch
save_total_limit: # Checkpoints saved at a time save_total_limit: # checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps: max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 # save model as safetensors (require safetensors package)
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# Save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
# Whether to mask out or include the human's prompt from the training labels # whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# Group similarly sized data to minimize padding. # group similarly sized data to minimize padding
# May be slower to start, as it must download and sort the entire dataset. # may be slower to start, as it must download and sort the entire dataset
# Note that training loss may have an oscillating pattern with this enabled. # note that training loss may have an oscillating pattern with this enabled
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false gradient_checkpointing: false
# Stop training after this many evaluation losses have increased in a row # stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3 early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer # specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
# For one_cycle optim # for one_cycle optim
lr_div_factor: # Learning rate div factor lr_div_factor: # learning rate div factor
# For log_sweep optim # for log_sweep optim
log_sweep_min_lr: log_sweep_min_lr:
log_sweep_max_lr: log_sweep_max_lr:
# Specify optimizer # specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_hf
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adafactor
# - adamw_anyprecision
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_lion_32bit
# - paged_lion_8bit
optimizer: optimizer:
# Specify weight decay # specify weight decay
weight_decay: weight_decay:
# adamw hyperparams # adamw hyperparams
adam_beta1: adam_beta1:
@@ -747,54 +582,47 @@ adam_epsilon:
# Gradient clipping max norm # Gradient clipping max norm
max_grad_norm: max_grad_norm:
# Augmentation techniques # whether to bettertransformers
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
noisy_embedding_alpha:
# Whether to bettertransformers
flash_optimum: flash_optimum:
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers: # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention: xformers_attention:
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention: # whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention: flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only # whether to use scaled-dot-product attention
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention: sdp_attention:
# Landmark attention (only llama) # Landmark attention (only llama)
landmark_attention: landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# LLaMA only # llama only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# Resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off. # if resume_from_checkpoint isn't set and you simply want it to start where it left off
# Be careful with this being turned on between different models. # be careful with this being turned on between different models
auto_resume_from_checkpoints: false auto_resume_from_checkpoints: false
# Don't mess with this, it's here for accelerate and torchrun # don't mess with this, it's here for accelerate and torchrun
local_rank: local_rank:
# Add or change special tokens. # add or change special tokens
# If you add tokens here, you don't need to add them to the `tokens` list.
special_tokens: special_tokens:
# bos_token: "<s>" # bos_token: "<s>"
# eos_token: "</s>" # eos_token: "</s>"
# unk_token: "<unk>" # unk_token: "<unk>"
# add extra tokens
# Add extra tokens.
tokens: tokens:
# FSDP # FSDP
fsdp: fsdp:
fsdp_config: fsdp_config:
# Deepspeed config path. e.g., deepspeed/zero3.json # Deepspeed config path
deepspeed: deepspeed:
# Advanced DDP Arguments # Advanced DDP Arguments
@@ -820,108 +648,21 @@ strict:
</details> </details>
<details>
<summary> Understanding of batch size and gradient accumulation steps </summary>
<br/>
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
**Example 1:**
Micro batch size: 3
Gradient accumulation steps: 2
Number of GPUs: 3
Total batch size = 3 * 2 * 3 = 18
```
| GPU 1 | GPU 2 | GPU 3 |
|----------------|----------------|----------------|
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|----------------|----------------|----------------|
| → (accumulate) | → (accumulate) | → (accumulate) |
|----------------|----------------|----------------|
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|----------------|----------------|----------------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
Weight update for w1:
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
```
**Example 2:**
Micro batch size: 2
Gradient accumulation steps: 1
Number of GPUs: 3
Total batch size = 2 * 1 * 3 = 6
```
| GPU 1 | GPU 2 | GPU 3 |
|-----------|-----------|-----------|
| S1, S2 | S3, S4 | S5, S6 |
| e1, e2 | e3, e4 | e5, e6 |
|-----------|-----------|-----------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
Weight update for w1:
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
```
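
The same arithmetic as the two examples above, written out as a quick sanity check (plain Python, not part of axolotl):

```python
def total_batch_size(micro_batch_size: int, gradient_accumulation_steps: int, num_gpus: int) -> int:
    # samples contributing to one optimizer step across all GPUs
    return micro_batch_size * gradient_accumulation_steps * num_gpus

assert total_batch_size(3, 2, 3) == 18  # Example 1
assert total_batch_size(2, 1, 3) == 6   # Example 2
```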
</details>
### Train ### Train
Run Run
```bash ```bash
accelerate launch -m axolotl.cli.train your_config.yml accelerate launch scripts/finetune.py your_config.yml
```
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
This is recommended for large datasets.
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- Use `--debug` to see preprocessed examples.
```bash
python -m axolotl.cli.preprocess your_config.yml
``` ```
#### Multi-GPU #### Multi-GPU
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed You can optionally pre-tokenize dataset with the following before finetuning:
is the recommended multi-GPU option currently because FSDP may experience ```bash
[loss instability](https://github.com/huggingface/transformers/issues/26498). CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
##### DeepSpeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```yaml
deepspeed: deepspeed/zero1.json
``` ```
```shell ##### Config
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```
##### FSDP
- llama FSDP - llama FSDP
```yaml ```yaml
@@ -934,6 +675,11 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
``` ```
- llama Deepspeed
```yaml
deepspeed: deepspeed/zero3.json
```
##### Weights & Biases Logging ##### Weights & Biases Logging
- wandb options - wandb options
@@ -952,44 +698,34 @@ Pass the appropriate flag to the train command:
- Pretrained LORA: - Pretrained LORA:
```bash ```bash
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir" --inference --lora_model_dir="./lora-output-dir"
``` ```
- Full weights finetune: - Full weights finetune:
```bash ```bash
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model" --inference --base_model="./completed-model"
``` ```
- Full weights finetune w/ a prompt from a text file: - Full weights finetune w/ a prompt from a text file:
```bash ```bash
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \ cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True --base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
``` ```
- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
Please use `--sample_packing False` if you have it on and receive the error similar to below:
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
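
For example, appending the flag to the inference command shown above (config path and adapter directory are placeholders):

```bash
python -m axolotl.cli.inference examples/your_config.yml \
    --lora_model_dir="./lora-output-dir" --sample_packing False
```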
### Merge LORA to base ### Merge LORA to base
Add below flag to train command above Add below flag to train command above
```bash ```bash
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
``` ```
If you run out of CUDA memory, you can try to merge in system RAM with If you run out of CUDA memory, you can try to merge in system RAM with
```bash ```bash
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ... CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
``` ```
## Common Errors 🧰 ## Common Errors 🧰
See also the [FAQ's](./docs/faq.md).
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it: > If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
Please reduce any below Please reduce any below
@@ -1016,10 +752,6 @@ Try to turn off xformers.
It's safe to ignore it. It's safe to ignore it.
> NCCL Timeouts during training
See the [NCCL](docs/nccl.md) guide.
## Need help? 🙋♂️ ## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you


@@ -1,41 +0,0 @@
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}


@@ -1,45 +1,46 @@
{ {
"zero_optimization": { "zero_optimization": {
"stage": 2, "stage": 2,
"offload_optimizer": { "offload_optimizer": {
"device": "cpu" "device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
}, },
"contiguous_gradients": true, "bf16": {
"overlap_comm": true "enabled": "auto"
}, },
"bf16": { "fp16": {
"enabled": "auto" "enabled": "auto",
}, "auto_cast": false,
"fp16": { "loss_scale": 0,
"enabled": "auto", "initial_scale_power": 32,
"auto_cast": false, "loss_scale_window": 1000,
"loss_scale": 0, "hysteresis": 2,
"initial_scale_power": 32, "min_loss_scale": 1
"loss_scale_window": 1000, },
"hysteresis": 2, "optimizer": {
"min_loss_scale": 1 "type": "AdamW",
}, "params": {
"optimizer": { "lr": "auto",
"type": "AdamW", "betas": [
"params": { 0.9,
"lr": "auto", 0.999
"betas": "auto", ],
"eps": "auto", "eps": 1e-8,
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": { "scheduler": {
"type": "WarmupDecayLR", "type": "WarmupDecayLR",
"params": { "params": {
"warmup_min_lr": "auto", "warmup_min_lr": "auto",
"warmup_max_lr": "auto", "warmup_max_lr": "auto",
"warmup_num_steps": "auto", "warmup_num_steps": "auto",
"warmup_type": "linear", "total_num_steps": "auto"
"total_num_steps": "auto" }
} },
}, "train_batch_size": "auto",
"gradient_accumulation_steps": "auto", "train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto", "wall_clock_breakdown": false
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
} }


@@ -1,6 +1,14 @@
{ {
"zero_optimization": { "zero_optimization": {
"stage": 3, "stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true, "overlap_comm": true,
"contiguous_gradients": true, "contiguous_gradients": true,
"sub_group_size": 0, "sub_group_size": 0,
@@ -28,21 +36,18 @@
"params": { "params": {
"lr": "auto", "lr": "auto",
"betas": "auto", "betas": "auto",
"eps": "auto", "eps": 1e-8,
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": { "scheduler": {
"type": "WarmupDecayLR", "type": "WarmupLR",
"params": { "params": {
"warmup_min_lr": "auto", "warmup_min_lr": "auto",
"warmup_max_lr": "auto", "warmup_max_lr": "auto",
"warmup_num_steps": "auto", "warmup_num_steps": "auto"
"warmup_type": "linear",
"total_num_steps": "auto"
} }
}, },
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false "wall_clock_breakdown": false


@@ -9,11 +9,6 @@ services:
 - ~/.cache/huggingface/:/root/.cache/huggingface/
 # set environment variables
 environment:
-# Set environment variables
-- GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}
-- GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}
-- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
-- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
 - WANDB_API_KEY=${WANDB_API_KEY}
 deploy:
 resources:
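
Compose resolves these `${...}` references from the host shell or an `.env` file; a minimal sketch of providing them before starting the stack (values are placeholders, and the `GIT_*` entries only matter for the variant of the file that still lists them):

```bash
export WANDB_API_KEY=xxxxxxxxxxxxxxxx     # placeholder token
export GIT_AUTHOR_NAME="Your Name"        # only if your compose file lists the GIT_* variables
export GIT_AUTHOR_EMAIL="you@example.com"
docker compose up -d
```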


@@ -5,9 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y vim curl apt-get install -y vim curl
@@ -15,19 +12,17 @@ RUN apt-get update && \
WORKDIR /workspace WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt RUN cd axolotl && \
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[deepspeed,flash-attn]; \ pip install -e .[flash-attn,gptq]; \
fi fi
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ RUN cd axolotl && \
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch git config --get remote.origin.fetch
# helper for huggingface-login cli # helper for huggingface-login cli


@@ -10,28 +10,70 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9" ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.0.1" ARG PYTORCH_VERSION="2.0.1"
ARG CUDA="118" ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \ RUN apt-get update
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ RUN apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
&& wget \
RUN wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \ && bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \ && rm -f Miniconda3-latest-Linux-x86_64.sh
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
RUN conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
RUN git lfs install --skip-repo && \ FROM base-builder AS deepspeed-builder
pip3 install awscli && \
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
WORKDIR /workspace
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
cd DeepSpeed && \
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel
FROM base-builder AS bnb-builder
WORKDIR /workspace
ARG CUDA="118"
ENV CUDA=$CUDA
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
cd bitsandbytes && \
CUDA_VERSION=$CUDA make cuda11x && \
python setup.py bdist_wheel
FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
RUN mkdir -p /workspace/wheels/bitsandbytes
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
RUN pip3 install wheels/deepspeed-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -1,18 +0,0 @@
# Axolotl FAQs
> The trainer stopped and hasn't progressed in several minutes.
This is usually an issue with the GPUs communicating with each other. See the [NCCL doc](../docs/nccl.md).
> Exitcode -9
This usually happens when you run out of system RAM.
> Exitcode -7 while using deepspeed
Try upgrading deepspeed with: `pip install -U deepspeed`
> AttributeError: 'DummyOptim' object has no attribute 'step'
You may be using deepspeed with a single GPU. Please don't set `deepspeed:` in the yaml or on the CLI.

View File

@@ -1,45 +0,0 @@
# Multi Node
You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:
~/.cache/huggingface/accelerate/default_config.yaml
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines
main_process_ip: 10.0.0.4 # Set to main machine's IP
main_process_port: 5000
main_training_function: main
mixed_precision: bf16
num_machines: 2 # Change to the number of machines
num_processes: 4 # The total number of GPUs across all machines (for example, if you have 2 machines with 4 GPUs each, put 8)
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
Configure your model to use FSDP, for example:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Machine configuration
On each machine you need a copy of Axolotl; we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
On the main machine only, make sure the port you set as `main_process_port` is open for TCP and reachable by the other machines.
All you have to do now is launch with accelerate on each machine as you usually would; the processes will start once you have launched accelerate on every machine.

View File

@@ -1,51 +0,0 @@
# Multipack
4k context, bsz = 4;
each character represents 256 tokens
X represents a padding token
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B ]
C C C C C C C ]
D D D D ]]
[[ E E E E E E E E ]
[ F F F F ]
[ G G G ]
[ H H H H ]]
[[ I I I ]
[ J J J ]
[ K K K K K]
[ L L L ]]
```
after padding to the longest input in each step
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B X X X X X X ]
C C C C C C C X X X X ]
D D D D X X X X X X X ]]
[[ E E E E E E E E ]
[ F F F F X X X X ]
[ G G G X X X X X ]
[ H H H H X X X X ]]
[[ I I I X X ]
[ J J J X X ]
[ K K K K K ]
[ L L L X X ]]
```
with packing (note it's the same effective number of tokens per step, but a true bsz of 1)
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A B B B B B
B C C C C C C C D D D D E E E E
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```

View File

@@ -1,46 +0,0 @@
# NCCL
NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out, causing the training process to abort:
```text
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
```
Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
```shell
nvidia-smi nvlink --status
```
To force NCCL to use NVLink, simply set this in the environment:
```shell
export NCCL_P2P_LEVEL=NVL
```
If NVLink is not available in your environment, there are other options for ``NCCL_P2P_LEVEL`` in the table below:
| NCCL_P2P_LEVEL | Description |
| -------------- | ----------- |
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates than paths involving multiple bridges, but slower than direct GPU-to-GPU communication. |
| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
| PHB | P2P data transfers occur over PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (e.g. PIX, NVL). |
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
```shell
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
```shell
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
```
Finally, if you believe your training job needs more time, you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.
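As a rough sketch of what a longer timeout amounts to at the PyTorch level (illustrative only — in Axolotl you set `ddp_timeout` in the config rather than calling this yourself, and the two-hour value below is just an example; it assumes `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` are provided by your launcher):
```python
# Minimal sketch of raising the collective-operation timeout in PyTorch.
from datetime import timedelta

import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    timeout=timedelta(hours=2),  # default is 30 minutes
)
```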

View File

@@ -1,89 +0,0 @@
base_model: cerebras/btlm-3b-8k-base
model_type: AutoModelForCausalLM
tokenizer_type: GPT2Tokenizer
trust_remote_code: true
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_prepared_run
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false
sample_packing_eff_est:
sample_packing_seq_len_multiplier:
total_num_tokens:
lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000085
train_on_inputs: true
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
sdp_attention:
flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 32
eval_steps:
save_steps:
save_total_limit:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
pad_token: "<|endoftext|>"
fsdp:
# - full_shard
# - auto_wrap
fsdp_config:
# fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock

View File

@@ -1,4 +1,5 @@
base_model: cerebras/Cerebras-GPT-1.3B base_model: cerebras/Cerebras-GPT-1.3B
base_model_config: cerebras/Cerebras-GPT-1.3B
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
strict: false strict: false
@@ -6,8 +7,8 @@ push_dataset_to_hub:
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
@@ -49,7 +50,7 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,13 +11,12 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 100000
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -34,7 +34,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,16 +11,15 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 100000
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -36,7 +36,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,7 +56,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,13 +11,12 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 100000
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -34,7 +34,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,16 +11,15 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 100000
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -36,7 +36,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,7 +56,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,13 +11,12 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 100000
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -34,7 +34,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +54,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,16 +11,15 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 100000
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -36,7 +36,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,7 +56,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,8 +1,8 @@
base_model: tiiuae/falcon-7b base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false
@@ -11,8 +11,8 @@ push_dataset_to_hub:
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat type: alpaca:chat
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048

View File

@@ -1,11 +1,11 @@
# 1b: tiiuae/falcon-rw-1b # 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b # 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main # required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
# enable 4bit for QLoRA # enable 4bit for QLoRA
load_in_4bit: true load_in_4bit: true
@@ -17,8 +17,8 @@ datasets:
data_files: data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json - Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat" type: "alpaca:chat"
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
@@ -53,7 +53,7 @@ output_dir: ./qlora-out
# decrease if OOM, increase for max VRAM utilization # decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1 micro_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
num_epochs: 4 num_epochs: 3
# Optimizer for QLoRA # Optimizer for QLoRA
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
torchdistx_path: torchdistx_path:

View File

@@ -1,8 +1,8 @@
base_model: tiiuae/falcon-7b base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false
@@ -11,8 +11,8 @@ push_dataset_to_hub:
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat type: alpaca:chat
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
adapter: adapter:
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048

View File

@@ -1,4 +1,5 @@
base_model: EleutherAI/gpt-j-6b base_model: EleutherAI/gpt-j-6b
base_model_config: EleutherAI/gpt-j-6b
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
strict: false strict: false
@@ -6,8 +7,8 @@ push_dataset_to_hub:
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
@@ -46,7 +47,7 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,11 +1,12 @@
base_model: huggyllama/llama-7b base_model: huggyllama/llama-7b
base_model_config: huggyllama/llama-7b
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false
datasets: datasets:
- path: openaccess-ai-collective/jeopardy - path: openaccess-ai-collective/jeopardy
type: jeopardy type: jeopardy
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: lora_model_dir:
@@ -24,7 +25,7 @@ wandb_log_model:
output_dir: ./jeopardy-bot-7b output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -9,16 +9,12 @@ gradient_accumulation_steps: 2
micro_batch_size: 1 micro_batch_size: 1
```shell ```shell
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
``` ```
or or
```shell ```shell
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml accelerate launch scripts/finetune.py examples/llama-2/lora.yml
```
To launch a full finetuning with 16-bit precision:
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
``` ```

View File

@@ -1,72 +0,0 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
eval_steps: 0.05
eval_table_size:
save_steps:
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,7 +1,8 @@
base_model: TheBloke/Llama-2-7B-GPTQ base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false is_llama_derived_model: false
gptq: true gptq: true
gptq_disable_exllama: true gptq_bits: 4
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
tokenizer_use_fast: true tokenizer_use_fast: true
@@ -14,8 +15,8 @@ hf_use_auth_token: true
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
@@ -37,7 +38,7 @@ wandb_log_model:
output_dir: ./model-out output_dir: ./model-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_torch optimizer: adamw_torch
adam_beta2: 0.95 adam_beta2: 0.95
adam_eps: 0.00001 adam_eps: 0.00001
@@ -61,6 +62,8 @@ xformers_attention:
flash_attention: flash_attention:
sdp_attention: sdp_attention:
flash_optimum: flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100 warmup_steps: 100
eval_steps: eval_steps:
save_steps: save_steps:

View File

@@ -1,4 +1,5 @@
base_model: NousResearch/Llama-2-7b-hf base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,13 +11,12 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -34,7 +34,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,9 +54,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: NousResearch/Llama-2-7b-hf base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,8 +11,8 @@ strict: false
datasets: datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
@@ -19,7 +20,6 @@ lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -36,7 +36,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,8 +56,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
eval_table_size:
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: NousResearch/Llama-2-7b-hf base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -10,8 +11,8 @@ strict: false
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
output_dir: ./relora-out output_dir: ./relora-out
adapter: qlora adapter: qlora
@@ -19,7 +20,6 @@ lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16
@@ -40,7 +40,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -60,7 +60,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: 50 save_steps: 50
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,69 +0,0 @@
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/context-aware-splits-english
type: alpaca
dataset_prepared_path:
val_set_size: 200
output_dir: ./tiny-llama
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 50
eval_steps: 0.05
eval_table_size:
save_steps: 0.50
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,12 +0,0 @@
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing notable performance across a variety of benchmarks.
Fine-tune:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
```
If you run into CUDA OOM, use deepspeed with config zero2.json:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
```

View File

@@ -1,61 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.000005
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,78 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,11 +1,12 @@
base_model: mosaicml/mpt-7b base_model: mosaicml/mpt-7b
base_model_config: mosaicml/mpt-7b
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
load_in_8bit: false load_in_8bit: false
datasets: datasets:
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: lora_model_dir:
@@ -26,7 +27,7 @@ wandb_log_model:
output_dir: ./mpt-alpaca-7b output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -1,4 +1,5 @@
base_model: openlm-research/open_llama_3b_v2 base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false
@@ -8,12 +9,12 @@ push_dataset_to_hub:
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: lora_model_dir:
sequence_len: 1024 sequence_len: 256
sample_packing: true max_packed_sequence_len:
lora_r: lora_r:
lora_alpha: lora_alpha:
lora_dropout: lora_dropout:
@@ -28,11 +29,11 @@ wandb_log_model:
output_dir: ./openllama-out output_dir: ./openllama-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.000003 learning_rate: 0.00001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
float16: true float16: true
@@ -44,12 +45,12 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: xformers_attention: true
flash_attention: true flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 10
eval_steps: 0.05 eval_steps: 50
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,4 +1,5 @@
base_model: openlm-research/open_llama_3b_v2 base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: true load_in_8bit: true
@@ -8,12 +9,12 @@ push_dataset_to_hub:
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
sequence_len: 1024 sequence_len: 256
sample_packing: true max_packed_sequence_len:
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.0 lora_dropout: 0.0
@@ -32,9 +33,9 @@ wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-out output_dir: ./lora-out
gradient_accumulation_steps: 1 batch_size: 16
micro_batch_size: 2 micro_batch_size: 4
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
@@ -49,16 +50,16 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: xformers_attention: true
flash_attention: true flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 10
eval_steps: 0.05 eval_steps: 50
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:

View File

@@ -1,4 +1,5 @@
base_model: openlm-research/open_llama_3b_v2 base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false
@@ -8,12 +9,12 @@ push_dataset_to_hub:
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 1024 sequence_len: 2048
sample_packing: true max_packed_sequence_len: 2048
lora_r: 8 lora_r: 8
lora_alpha: 32 lora_alpha: 32
lora_dropout: 0.05 lora_dropout: 0.05
@@ -26,33 +27,33 @@ wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 1 batch_size: 4
micro_batch_size: 2 micro_batch_size: 4
num_epochs: 4 num_epochs: 2
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: false bf16: true
fp16: true fp16: false
tf32: false tf32: true
gradient_checkpointing: true gradient_checkpointing: true
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: xformers_attention: true
flash_attention: true flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:

View File

@@ -1,11 +0,0 @@
# Phi
Due to some nuances with the phi code, please use deepspeed when training phi for a full finetune.
```shell
accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
# OR
python -m axolotl.cli.train examples/phi/phi-qlora.yml
```

View File

@@ -1,74 +0,0 @@
base_model: microsoft/phi-1_5
model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len:
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.000003
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"

View File

@@ -1,74 +0,0 @@
base_model: microsoft/phi-1_5
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out
sequence_len: 1024
sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len:
adapter: qlora
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.000003
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"

View File

@@ -1,4 +1,5 @@
base_model: EleutherAI/pythia-12b-deduped base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
@@ -9,7 +10,7 @@ device_map: auto
datasets: datasets:
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
adapter: adapter:
lora_model_dir: lora_model_dir:

View File

@@ -1,9 +1,10 @@
base_model: EleutherAI/pythia-1.4b-deduped base_model: EleutherAI/pythia-1.4b-deduped
base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true load_in_8bit: true
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -23,15 +24,15 @@ wandb_log_model:
output_dir: ./lora-alpaca-pythia output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 4 num_epochs: 3
learning_rate: 0.00001 learning_rate: 0.00001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: true bf16: True
tf32: true tf32: True
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
weight_decay: 0.1 weight_decay: 0.1
eval_steps: 0.05 eval_steps: 20
logging_steps: 1 logging_steps: 1

View File

@@ -1,4 +1,5 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1 base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
trust_remote_code: trust_remote_code:
@@ -6,7 +7,7 @@ load_in_8bit: false
datasets: datasets:
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: lora_model_dir:
@@ -27,7 +28,7 @@ wandb_log_model:
output_dir: ./redpajama-alpaca-3b output_dir: ./redpajama-alpaca-3b
batch_size: 4 batch_size: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -1,10 +1,11 @@
base_model: replit/replit-code-v1-3b base_model: replit/replit-code-v1-3b
base_model_config: replit/replit-code-v1-3b
trust_remote_code: true trust_remote_code: true
load_in_8bit: false load_in_8bit: false
datasets: datasets:
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -26,7 +27,7 @@ wandb_log_model:
output_dir: ./lora-replit output_dir: ./lora-replit
batch_size: 8 batch_size: 8
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: optimizer:
torchdistx_path: torchdistx_path:
lr_scheduler: lr_scheduler:

View File

@@ -1,6 +1,7 @@
# An example finetuning Salesforce's XGen-7b model with 8k context using qlora # An example finetuning Salesforce's XGen-7b model with 8k context using qlora
# on Tim Dettmer's Guanaco dataset. # on Tim Dettmer's Guanaco dataset.
base_model: Salesforce/xgen-7b-8k-base base_model: Salesforce/xgen-7b-8k-base
base_model_config: Salesforce/xgen-7b-8k-base
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
@@ -15,8 +16,8 @@ datasets:
data_files: data_files:
- openassistant_best_replies_train.jsonl - openassistant_best_replies_train.jsonl
type: "completion" type: "completion"
dataset_prepared_path: dataset_prepared_path: last_run_prepared
val_set_size: 0.05 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
@@ -51,7 +52,7 @@ output_dir: ./qlora-out
# decrease if OOM, increase for max VRAM utilization # decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1 micro_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
num_epochs: 4 num_epochs: 3
# Optimizer for QLoRA # Optimizer for QLoRA
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
torchdistx_path: torchdistx_path:

Binary file not shown.

Before: image, 370 KiB (not shown)

View File

@@ -1,23 +1,23 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1 torch==2.0.1
auto-gptq==0.4.2 auto-gptq
packaging packaging
peft==0.6.0 peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697 transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
deepspeed
addict addict
evaluate
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets>=2.14.0 datasets
flash-attn>=2.3.0 flash-attn>=2.0.8
sentencepiece sentencepiece
wandb wandb
einops einops
xformers>=0.0.22 xformers
optimum==1.13.2 optimum
hf_transfer hf_transfer
colorama colorama
numba numba
@@ -30,11 +30,3 @@ scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.29
gradio
tensorboard
# remote filesystems
s3fs
gcsfs
# adlfs

View File

@@ -1,38 +1,262 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging import logging
import os
import random
import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch
import transformers import transformers
import yaml
from axolotl.cli import ( # add src to the pythonpath so we don't need to pip install this
check_accelerate_default_config, from art import text2art
check_user_token, from transformers import GenerationConfig, TextStreamer
do_inference,
do_merge_lora,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.cli.shard import shard
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
LOG = logging.getLogger("axolotl.scripts.finetune") from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
if is_main_process():
print(ascii_art)
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
LOG.warning(
str(
PendingDeprecationWarning(
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
)
)
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
@@ -45,6 +269,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -7,28 +7,17 @@ def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()] lines = [
r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r
]
for line in lines:
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif ( elif "flash-attn" not in line and line and line[0] != "#":
"flash-attn" not in line
and "deepspeed" not in line
and line
and line[0] != "#"
):
# Handle standard packages
_install_requires.append(line)
# TODO(wing) remove once xformers release supports torch 2.1.0
if "torch==2.1.0" in _install_requires:
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.append(
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
)
return _install_requires, _dependency_links
@@ -37,18 +26,20 @@ install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.3.0", version="0.1",
description="LLM Trainer", description="You know you're going to axolotl questions",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages(),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
"flash-attn": [ "gptq": [
"flash-attn>=2.3.0", "auto-gptq",
],
"deepspeed": [ "flash-attn": [
"flash-attn==2.0.8",
],
"extras": [
"deepspeed",
],
},

View File

@@ -1,358 +0,0 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
import random
import sys
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
import gradio as gr
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(ascii_text, font=font)
if is_main_process():
print(ascii_art)
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(show_api=False, share=True)
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
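# Usage sketch of the override behavior above (hypothetical config path and values):
# given a yaml that sets `micro_batch_size: 2`, a keyword passed from the CLI wins:
#
#   cfg = load_cfg(Path("examples/debug.yml"), micro_batch_size=4)
#   assert cfg.micro_batch_size == 4
#
# Keys missing from the yaml are only written through when `strict` is falsy.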
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, tokenizer
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token():
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

View File

@@ -1,36 +0,0 @@
"""
CLI to run inference on a trained model
"""
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.inference = True
if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)
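# Usage sketch (paths and flags are illustrative): fire maps CLI flags onto do_cli's
# keyword arguments, so a typical invocation looks like
#
#   python -m axolotl.cli.inference examples/llama-2/lora.yml --lora_model_dir="./lora-out" --gradio
#
# where any extra --key=value flags are forwarded to load_cfg as config overrides.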

View File

@@ -1,27 +0,0 @@
"""
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path
import fire
import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -1,53 +0,0 @@
"""
CLI to preprocess a dataset ahead of training
"""
import logging
from pathlib import Path
import fire
import transformers
from colorama import Fore
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ Fore.RESET
)
if __name__ == "__main__":
fire.Fire(do_cli)
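# Usage sketch (illustrative path): preprocessing tokenizes the configured datasets and
# writes them to dataset_prepared_path so later training runs can reuse them, e.g.
#
#   python -m axolotl.cli.preprocess examples/llama-2/lora.yml --dataset_prepared_path=last_run_prepared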

View File

@@ -1,42 +0,0 @@
"""
CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.scripts")
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -1,38 +0,0 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_cli)
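# Usage sketch (illustrative config path): multi-GPU runs typically go through
# accelerate, which then calls this fire entrypoint:
#
#   accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
#
# Any extra --key=value flags are passed to load_cfg as overrides of the yaml.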

View File

@@ -25,22 +25,11 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer(
*,
cfg: DictDefault,

View File

@@ -1,5 +0,0 @@
"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -0,0 +1,144 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Union
from datasets import Dataset as Dataset_ds
from datasets import DatasetDict, IterableDataset, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
logger = logging.getLogger("axolotl")
class DsType(Enum):
JSON = "json"
ARROW = "arrow"
PARQUET = "parquet"
@dataclass
class DatasetConfiguration:
path: str
type: str
name: Optional[str] = field(
default=None,
metadata={"help": "the name of the dataset configuration to load."},
)
ds_type: Optional[DsType] = None
data_files: Optional[Union[str, List[str]]] = None
shards: Optional[int] = None
test_size: Optional[float] = None
@staticmethod
def from_dict(d: Dict[str, Any]) -> Generator["DatasetConfiguration", None, None]:
if "name" in d and isinstance(d["name"], list):
name = d.pop("name")
for n in name:
yield DatasetConfiguration(
**d,
name=n,
)
else:
# fall back to a single configuration when "name" is absent or a scalar
yield DatasetConfiguration(**d)
def load_dataset_from_local(config: DatasetConfiguration) -> Optional[Dataset_ds]:
local_path = Path(config.path)
if not local_path.exists():
return None
ds = None
if local_path.is_dir():
if config.ds_type:
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_from_disk(config.path)
else:
ds = load_dataset(
config.path,
name=config.name,
data_files=config.data_files,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = "json"
if config.ds_type:
ds_type = config.ds_type.value
elif "parquet" in config.path:
ds_type = "parquet"
elif "arrow" in config.path:
ds_type = "arrow"
ds = load_dataset(
ds_type,
name=config.name,
data_files=config.path,
streaming=False,
split=None, # is this correct?
)
if not ds:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory nor a file"
)
return ds
# TODO should this be a DatasetDict?
class Dataset(Dataset_ds):
_config: DatasetConfiguration
def __init__(self, *args, config: DatasetConfiguration = None, **kwargs):
self._config = config
super().__init__(*args, **kwargs)
@staticmethod
def from_config(
config: DatasetConfiguration,
token: bool = False,
default_test_size: float = 0.1,
):
ds = load_dataset_from_local(config)
if not ds:
try:
ds = load_dataset(
config.path,
name=config.name,
data_files=config.data_files,
token=token,
)
except FileNotFoundError:
pass
if not ds:
fp = hf_hub_download(
repo_id=config.path,
repo_type="dataset",
filename=config.data_files,
token=token,
)
ds = load_dataset(
"json", name=config.name, data_files=fp, streaming=False, split=None
)
if not ds:
raise ValueError("unhandled dataset load")
test_size = config.test_size if config.test_size else default_test_size
# determine if the dataset is pre-tokenized
check_ds = ds["train"] if isinstance(ds, DatasetDict) and "train" in ds else ds
is_ds_tokenized = False
if "input_ids" in check_ds.features:
is_ds_tokenized = True
if "attention_mask" not in check_ds.features:
logger.warning("`attention_mask` missing from pre-tokenized dataset")
if "labels" not in check_ds.features:
logger.warning("`labels` missing from pre-tokenized dataset")
if test_size and (not isinstance(ds, DatasetDict) or "test" not in ds):
# train_test_split returns a new DatasetDict rather than splitting in place
ds = ds.train_test_split(test_size=test_size, shuffle=False)
return ds
class DatasetCollection:
datasets: List[Dataset] = []
def __init__(self, datasets: Union[Dataset, List[Dataset]]):
self.datasets = datasets if isinstance(datasets, list) else [datasets]
def __iter__(self):
for ds in self.datasets:
for d in ds:
yield d
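# Usage sketch for the WIP abstractions above (the dataset path and field values are
# illustrative, and from_config's split/return handling is still unfinished here):
#
#   config = next(DatasetConfiguration.from_dict(
#       {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca", "name": ["default"]}
#   ))
#   ds = Dataset.from_config(config, token=True, default_test_size=0.05)
#   collection = DatasetCollection([ds])
#   for row in collection:
#       ...  # iterates rows across every member dataset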

View File

@@ -1,748 +0,0 @@
"""
Builder for the training args and trainer
"""
import abc
import importlib
import logging
import math
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional
import torch
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import seed_worker
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
LOG = logging.getLogger("axolotl.core.trainer_builder")
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=(
self.train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
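# Note on the packed-length calculation above: with sample packing each packed row
# carries position_ids that end at its last token index, so x[-1] + 1 recovers the
# row's token count (e.g. a row whose position_ids end at 511 is 512 tokens long),
# which MultipackBatchSampler uses to fill batches up to batch_max_len.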
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=(
eval_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing:
train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the ReLoRA scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
_train_dataset = None
_eval_dataset = None
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@abstractmethod
def build(self, total_num_steps):
pass
@abstractmethod
def get_callbacks(self):
pass
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
"""
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def get_callbacks(self):
callbacks = []
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback)
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
return AxolotlTrainer
def build(self, total_num_steps):
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
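# Worked example: with total_num_steps=1000 and neither value set in the config,
# warmup_steps = min(int(0.03 * 1000), 100) = 30 and
# logging_steps = max(min(int(0.005 * 1000), 10), 1) = 5.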
training_arguments_kwargs = {}
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = self.cfg.bf16
training_arguments_kwargs["fp16"] = (
self.cfg.fp16 and not self.cfg.bf16
) or False
training_arguments_kwargs["tf32"] = self.cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if self.cfg.seed:
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
if self.cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
if self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
if self.cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
elif torch._dynamo: # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
training_arguments_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
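# Worked example: with pad_to_sequence_len set and sequence_len=2048 the collator pads
# to a multiple of 64 * math.ceil(2048 / 64) = 2048 tokens; otherwise to a multiple of 64.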
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(self.model, self.tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [self.train_dataset, self.eval_dataset]:
dataset = dataset.map(
partial(
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
),
batched=False,
num_proc=32,
)
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer

View File

@@ -2,7 +2,7 @@
import logging
import os
from typing import List, Optional from typing import List
import torch
from datasets import Dataset, IterableDataset
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
"""
@@ -30,29 +30,18 @@ class TokenizedPromptDataset(Dataset):
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
process_count: Optional[int] = None,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
self.process_count = process_count
super().__init__(self.process(dataset).data, **kwargs)
def process(self, dataset):
features = dataset.features.keys()
num_proc = ( num_proc = min(64, os.cpu_count())
min(64, self.process_count)
if self.process_count
else min(64, os.cpu_count())
)
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
**map_kwargs,
)
@@ -61,7 +50,7 @@ class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data. tokenizer (Tokenizer): The processor used for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""

View File

@@ -23,7 +23,6 @@ class ColorfulFormatter(Formatter):
}
def format(self, record):
record.rank = int(os.getenv("LOCAL_RANK", "0"))
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
@@ -36,7 +35,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s", "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
},
"filters": {},

View File

@@ -1,6 +0,0 @@
"""
MixFormers model architecture used for phi models
"""
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa

View File

@@ -1,63 +0,0 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Any, Dict, List, Optional, Union
from transformers import PretrainedConfig
class MixFormerSequentialConfig(PretrainedConfig):
"""MixFormer (sequential for DeepSpeed) configuration."""
model_type = "mixformer-sequential"
attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
"input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
"blocks": "architecture", # `blocks` key is for backward compatibility
}
def __init__(
self,
vocab_size: Optional[int] = 50304,
n_positions: Optional[int] = 2048,
n_embd: Optional[int] = 1024,
n_layer: Optional[int] = 20,
n_inner: Optional[int] = None,
n_head: Optional[int] = 16,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
embd_layer: Optional[str] = "default",
architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
embd_pdrop: Optional[float] = 0.0,
resid_pdrop: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-5,
initializer_range: Optional[float] = 0.02,
tie_word_embeddings: Optional[bool] = False,
pad_vocab_size_multiple: Optional[int] = 64,
**kwargs
) -> None:
self.vocab_size = int(
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.embd_layer = embd_layer
self.architecture = architecture
self.embd_pdrop = embd_pdrop
self.resid_pdrop = resid_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

View File

@@ -1,930 +0,0 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import (
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
from .configuration_mixformer_sequential import MixFormerSequentialConfig
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
Adapted from https://github.com/Dao-AILab/flash-attention."""
max_sequence_len: int
max_batch_size: int
sequence_len_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
fused_ft_kernel: bool = False
lengths_per_sample: Optional[torch.Tensor] = None
class Embedding(nn.Module):
"""Token embedding with dropout."""
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.wte(input_ids)
hidden_states = self.drop(hidden_states)
return hidden_states
class RotaryEmbedding(nn.Module):
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
Adapted from https://github.com/Dao-AILab/flash-attention."""
def __init__(
self,
dim: int,
base: Optional[int] = 10000,
scale_base: Optional[float] = None,
device: Optional[str] = None,
**kwargs,
) -> None:
super().__init__()
if scale_base is not None:
raise NotImplementedError
# Generate and save the inverse frequency buffer (non-trainable)
self.dim = dim
self.base = base
self.scale_base = scale_base
self.device = device
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
self.register_buffer("inv_freq", inv_freq)
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _update_cos_sin_cache(
self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
) -> None:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
seqlen = x.shape[1] + seqlen_offset
# Re-generate the inverse frequency buffer if it's not fp32
# (for instance if model.half() was called)
if self.inv_freq.dtype != "torch.float32":
self.inv_freq = 1.0 / (
self.base
** (
torch.arange(
0, self.dim, 2, device=self.device, dtype=torch.float32
)
/ self.dim
)
)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(
t, self.inv_freq.to(device=t.device, dtype=torch.float32)
)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(x.dtype)
self._sin_cached = torch.sin(freqs).to(x.dtype)
else:
power = (
torch.arange(
seqlen, dtype=self.scale.dtype, device=self.scale.device
)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(
power, "s -> s 1"
)
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
def apply_rotary_emb_qkv(
self,
qkv: torch.FloatTensor,
sin: torch.FloatTensor,
cos: torch.FloatTensor,
sin_k: Optional[torch.FloatTensor] = None,
cos_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
_, seqlen, three, _, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert (
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
)
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
# Splits the queries and keys in half
q1, q2 = q_rot.chunk(2, dim=-1)
k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
sin[:seqlen], "s d -> s 1 d"
)
# Casts to fp32 are necessary to prevent fp16 overflow issues
q1, q2, k1, k2, c, s = [
t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
]
# Computes the new keys and queries, recasting to original dtype
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
return torch.cat(
[
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
qkv[:, :, 2:3, :, :],
],
axis=2,
)
def forward(
self, qkv: torch.Tensor, seqlen_offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform the forward pass.
Args:
qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
Returns:
New `qkv` and the cached sinusoids.
"""
self._update_cos_sin_cache(qkv, seqlen_offset)
return self.apply_rotary_emb_qkv(
qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
)
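# --- illustrative sketch (added for clarity, not part of the original file) ---
# A hedged example of how this RotaryEmbedding is typically driven: `qkv` is the
# packed projection of shape (batch, seqlen, 3, n_heads, head_dim), only the first
# `dim` channels of q and k are rotated, and the cos/sin tables are cached per
# sequence length. All sizes below are made up for the demo.
def _rotary_embedding_demo():
    batch, seqlen, n_heads, head_dim, rotary_dim = 2, 16, 4, 32, 32
    rope = RotaryEmbedding(rotary_dim)
    qkv = torch.randn(batch, seqlen, 3, n_heads, head_dim)
    rotated = rope(qkv)  # same shape as the input qkv
    assert rotated.shape == (batch, seqlen, 3, n_heads, head_dim)
    # the cached tables hold one row per position and dim // 2 frequencies
    assert rope._cos_cached.shape == (seqlen, rotary_dim // 2)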
def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
Adapted from https://github.com/Dao-AILab/flash-attention."""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size,
inference_params.max_sequence_len,
2,
num_heads,
head_dim,
dtype=kv.dtype,
device=kv.device,
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (
kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa
)
assert sequence_end <= (
kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa
)
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
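# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged example of the cache layout `_update_kv_cache` expects: one pre-allocated
# buffer per layer of shape (max_batch_size, max_sequence_len, 2, n_heads, head_dim),
# written in place at the current batch/sequence offsets. The `params` object below
# is a stand-in that only mimics the InferenceParams fields this helper reads.
def _kv_cache_demo():
    from types import SimpleNamespace

    n_heads, head_dim = 4, 32
    params = SimpleNamespace(
        max_batch_size=1,
        max_sequence_len=64,
        batch_size_offset=0,
        sequence_len_offset=0,
        key_value_memory_dict={},
    )
    # prompt step: 8 tokens of packed key/value states for layer 0
    kv = torch.randn(1, 8, 2, n_heads, head_dim)
    assert _update_kv_cache(kv, params, layer_idx=0).shape == (1, 8, 2, n_heads, head_dim)
    # decode step: one new token appended at offset 8
    params.sequence_len_offset = 8
    step = torch.randn(1, 1, 2, n_heads, head_dim)
    assert _update_kv_cache(step, params, layer_idx=0).shape == (1, 9, 2, n_heads, head_dim)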
class MLP(nn.Module):
"""Multi-Layer Perceptron.
Reference:
Attention Is All You Need.
https://arxiv.org/pdf/1706.03762.pdf.
"""
def __init__(
self,
config: PretrainedConfig,
n_inner: Optional[int] = None,
act_fn: Optional[str] = None,
) -> None:
super().__init__()
act_fn = config.activation_function if act_fn is None else act_fn
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
self.fc1 = nn.Linear(config.n_embd, n_inner)
self.fc2 = nn.Linear(n_inner, config.n_embd)
self.act = ACT2FN[act_fn]
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
old_keys = [
prefix + "fc_in.weight",
prefix + "fc_out.weight",
prefix + "fc_in.bias",
prefix + "fc_out.bias",
]
new_keys = [
prefix + "fc1.weight",
prefix + "fc2.weight",
prefix + "fc1.bias",
prefix + "fc2.bias",
]
if all(k in state_dict for k in old_keys) and not all(
k in state_dict for k in new_keys
):
# Older version of `MLP` saved with different key names.
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
return super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
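# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged example of the checkpoint-compatibility shim above: a state dict saved with
# the older fc_in/fc_out names is renamed to fc1/fc2 on the fly while loading. The
# tiny config values are made up for the demo.
def _mlp_key_remap_demo():
    from transformers import PretrainedConfig

    config = PretrainedConfig(n_embd=8, activation_function="gelu_new")
    mlp = MLP(config)  # n_inner defaults to 4 * n_embd = 32
    old_state = {
        "fc_in.weight": torch.randn(32, 8),
        "fc_in.bias": torch.randn(32),
        "fc_out.weight": torch.randn(8, 32),
        "fc_out.bias": torch.randn(8),
    }
    result = mlp.load_state_dict(old_state, strict=False)
    assert not result.unexpected_keys  # old keys were remapped, none left over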
class FusedMLP(nn.Module):
"""Fused Multi-Layer Perceptron from `flash-attn`.
Reference:
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
"""
def __init__(
self,
config: PretrainedConfig,
n_inner: Optional[int] = None,
act_fn: Optional[str] = None,
raise_on_missing: bool = False,
) -> None:
super().__init__()
act_fn = config.activation_function if act_fn is None else act_fn
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa
activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa
self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
return self.mlp(hidden_states)
class SelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Adapted from https://github.com/Dao-AILab/flash-attention.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(
self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, S)
"""
causal = self.causal if causal is None else causal
if cu_seqlens is not None:
return flash_attn_varlen_qkvpacked_func(
qkv.squeeze(0),
cu_seqlens,
max_seqlen,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
else:
return flash_attn_qkvpacked_func(
qkv,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
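# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged example of the packed layout SelfAttention.forward consumes: one
# (batch, seqlen, 3, n_heads, head_dim) tensor with q/k/v stacked on dim=2, plus
# optional cumulative `cu_seqlens` boundaries when several samples are packed into a
# single row. The flash-attn kernels themselves need a CUDA device and fp16/bf16
# tensors, so this only prepares the inputs.
def _packed_qkv_demo():
    batch, seqlen, n_heads, head_dim = 2, 16, 4, 32
    q = torch.randn(batch, seqlen, n_heads, head_dim)
    k = torch.randn(batch, seqlen, n_heads, head_dim)
    v = torch.randn(batch, seqlen, n_heads, head_dim)
    qkv = torch.stack([q, k, v], dim=2)  # (batch, seqlen, 3, n_heads, head_dim)
    # two packed samples of 10 and 6 tokens in one row -> cumulative offsets
    cu_seqlens = torch.tensor([0, 10, 16], dtype=torch.int32)
    max_seqlen = 10
    return qkv, cu_seqlens, max_seqlen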
class CrossAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Adapted from https://github.com/Dao-AILab/flash-attention.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(self, q, kv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, Sk)
"""
causal = self.causal if causal is None else causal
return flash_attn_kvpacked_func(
q,
kv,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
def find_mha_dims(
config: PretrainedConfig,
n_head: Optional[int] = None,
head_dim: Optional[int] = None,
) -> Tuple[int, int]:
"""Validate and return the number of heads and head dimension for multi-head attention.
Args:
config: Model configuration.
n_head: Number of heads.
head_dim: Head dimension.
Returns:
Number of heads and head dimension.
"""
assert all(
hasattr(config, attr) for attr in ["n_embd", "n_head"]
), "`config` must have `n_embd` and `n_head` attributes."
if head_dim is None:
assert (
config.n_embd % config.n_head == 0
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
if n_head is None and head_dim is None:
head_dim = config.n_embd // config.n_head
n_head = config.n_head
elif n_head is None or head_dim is None:
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
return n_head, head_dim
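# --- illustrative sketch (added for clarity, not part of the original file) ---
# Worked example of the bookkeeping above: a hidden size of 2048 split across 32
# heads gives 2048 // 32 = 64 channels per head. The config object is a hypothetical
# stand-in carrying only the two required attributes.
def _find_mha_dims_demo():
    from types import SimpleNamespace

    config = SimpleNamespace(n_embd=2048, n_head=32)
    assert find_mha_dims(config) == (32, 64)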
class MHA(nn.Module):
"""Multi-head attention layer.
Adapted from https://github.com/Dao-AILab/flash-attention."""
def __init__(
self,
config: PretrainedConfig,
rotary_dim: Optional[int] = None,
n_head: Optional[int] = None,
head_dim: Optional[int] = None,
bias: Optional[bool] = True,
dropout: Optional[float] = 0.0,
softmax_scale: Optional[float] = None,
causal: Optional[bool] = True,
layer_idx: Optional[int] = None,
rotary_emb_scale_base: Optional[float] = None,
return_residual: Optional[bool] = False,
checkpointing: Optional[bool] = False,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
fused_dense: Optional[bool] = True,
flash_attn: Optional[bool] = True,
cutlass_attn: Optional[bool] = False,
flash_rotary: Optional[bool] = True,
raise_on_missing: Optional[bool] = False,
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
n_head, head_dim = find_mha_dims(config, n_head, head_dim)
self.hidden_size = config.n_embd
self.n_head = n_head
self.head_dim = head_dim
self.op_size = n_head * head_dim
self.causal = causal
self.layer_idx = layer_idx
self.rotary_emb_dim = (
rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
)
self.fused_dense = fused_dense
self.flash_attn = flash_attn
self.cutlass_attn = cutlass_attn
self.flash_rotary = flash_rotary
self.return_residual = return_residual
self.checkpointing = checkpointing
if self.rotary_emb_dim > 0:
rotary_kwargs = {"device": device}
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
rotary_kwargs["scale_base"] = rotary_emb_scale_base
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
else:
pass
self.Wqkv = nn.Linear(
self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
)
self.out_proj = nn.Linear(
self.op_size, self.hidden_size, bias=bias, **factory_kwargs
)
self.inner_attn = SelfAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
self.inner_cross_attn = CrossAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
def _update_kv_cache(
self, kv: torch.FloatTensor, inference_params: InferenceParams
) -> None:
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
Adapted from https://github.com/Dao-AILab/flash-attention."""
assert (
self.layer_idx is not None
), "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def forward(
self,
x: torch.FloatTensor,
x_kv: Optional[torch.FloatTensor] = None,
key_padding_mask: Optional[torch.BoolTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
mixer_subset: Optional[torch.LongTensor] = None,
past_cache: Optional[InferenceParams] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Perform the forward pass.
Args:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
is the sum of the sequence lengths in the batch.
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into x. Only applicable when using
FlashAttention.
max_seqlen: int. Maximum sequence length in the batch.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
past_cache: For generation only.
Returns:
(batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
else (total, hidden_dim) where total is the sum of the sequence lengths
in the batch.
"""
if cu_seqlens is not None:
assert max_seqlen is not None
assert key_padding_mask is None
assert self.flash_attn
# assert self.rotary_emb_dim == 0
if key_padding_mask is not None:
assert cu_seqlens is None
assert max_seqlen is None
assert not self.flash_attn
if past_cache is not None:
assert key_padding_mask is None
assert cu_seqlens is None and max_seqlen is None
attn_kwargs = {"key_padding_mask": key_padding_mask}
assert x_kv is None and mixer_subset is None
qkv = self.Wqkv(x)
qkv = rearrange(
qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
)
if past_cache is None:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
context = self.inner_attn(
qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
)
else:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
q = qkv[:, :, 0]
kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if past_cache.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
out = rearrange(context, "... h d -> ... (h d)")
out = self.out_proj(out)
return out if not self.return_residual else (out, x)
class ParallelBlock(nn.Module):
"""Parallel block.
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
"""
def __init__(
self,
config: PretrainedConfig,
mixer: Optional[Dict[str, Any]] = None,
mlp: Optional[Dict[str, Any]] = None,
block_idx: Optional[int] = None,
) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.block_idx = block_idx
self.mixer = MHA(config, layer_idx=block_idx)
self.mlp = MLP(config)
def forward(
self,
hidden_states: torch.FloatTensor,
past_cache: Optional[torch.FloatTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
hidden_states,
past_cache=past_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if isinstance(attn_outputs, tuple):
attn_outputs = attn_outputs[0]
attn_outputs = self.resid_dropout(attn_outputs)
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states
class CausalLMHead(nn.Module):
"""Causal Language Modeling head.
Reference:
Improving Language Understanding by Generative Pre-Training.
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
"""
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.linear = nn.Linear(config.n_embd, config.vocab_size)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.ln(hidden_states)
logits = self.linear(hidden_states).to(torch.float32)
return logits
class CausalLMLoss(nn.Module):
"""Causal Language Modeling loss.
Reference:
Improving Language Understanding by Generative Pre-Training.
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
"""
def __init__(self, shift_labels: Optional[bool] = True) -> None:
super().__init__()
self.shift_labels = shift_labels
self.loss_fct = nn.CrossEntropyLoss()
def forward(
self, logits: torch.FloatTensor, labels: torch.LongTensor
) -> torch.FloatTensor:
if self.shift_labels:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return loss
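# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged example of the one-token label shift above: the logits at position t are
# scored against the token at position t + 1, so the last logit row and the first
# label are dropped before the cross-entropy.
def _causal_lm_loss_demo():
    vocab_size, seqlen = 11, 5
    logits = torch.randn(1, seqlen, vocab_size)
    labels = torch.randint(0, vocab_size, (1, seqlen))
    loss = CausalLMLoss()(logits, labels)
    manual = nn.CrossEntropyLoss()(
        logits[:, :-1, :].reshape(-1, vocab_size), labels[:, 1:].reshape(-1)
    )
    assert torch.allclose(loss, manual)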
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
"""MixFormer (sequential for DeepSpeed) pre-trained model."""
config_class = MixFormerSequentialConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
def __init__(self, *inputs, **kwargs) -> None:
super().__init__(*inputs, **kwargs)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, **kwargs
) -> Dict[str, Any]:
if "use_cache" in kwargs and not kwargs["use_cache"]:
return {"input_ids": input_ids}
if past_key_values is None or not (
isinstance(past_key_values, InferenceParams)
):
past_key_values = InferenceParams(
max_batch_size=input_ids.shape[0],
max_sequence_len=self.config.n_positions,
sequence_len_offset=0,
batch_size_offset=0,
fused_ft_kernel=False,
key_value_memory_dict={},
)
else:
# assume past_key_values has cached all but last token in input_ids
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
class PackedSequential(nn.Sequential):
def forward(
self,
input,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
):
for module in self:
sig = inspect.signature(module.forward)
if "cu_seqlens" in sig.parameters:
input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
else:
input = module(input)
return input
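# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged example of the dispatch rule above: modules whose forward signature accepts
# `cu_seqlens` receive the packing metadata, plain modules get the hidden states
# only. Both toy layers are made up for the demo.
def _packed_sequential_demo():
    class _PlainLayer(nn.Module):
        def forward(self, x):
            return x + 1

    class _PackedAwareLayer(nn.Module):
        def forward(self, x, cu_seqlens=None, max_seqlen=None):
            return x * 2

    seq = PackedSequential(_PlainLayer(), _PackedAwareLayer())
    out = seq(torch.zeros(3), cu_seqlens=torch.tensor([0, 3]), max_seqlen=3)
    assert torch.equal(out, torch.full((3,), 2.0))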
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
_keys_to_ignore_on_load_missing = [""]
_keys_to_ignore_on_load_unexpected = [
r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
]
_no_split_modules = ["ParallelBlock"]
def __init__(self, config: MixFormerSequentialConfig) -> None:
super().__init__(config)
modules = [Embedding(config)]
block_config = config.architecture
if not isinstance(block_config, list):
block_config = [block_config for _ in range(config.n_layer)]
if config.n_layer != len(block_config):
config.n_layer = len(block_config)
for block_idx, block in enumerate(block_config):
# `block_cls` with `legacy` value is for backward compatibility
# `path` key is for backward compatibility
block = copy.deepcopy(block) or {"block_cls": "parallel"}
block.pop("path", None) or block.pop("block_cls", None)
block["block_idx"] = block_idx
modules.append(ParallelBlock(config, **block))
modules.append(CausalLMHead(config))
self.layers = PackedSequential(*modules)
self.loss = CausalLMLoss()
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.layers[0].wte
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.layers[0].wte = new_embeddings
def get_output_embeddings(self) -> nn.Linear:
return self.layers[-1].linear
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.layers[-1].linear = new_embeddings
def forward(
self,
input_ids: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
cu_seqlens: Optional[torch.LongTensor] = None
max_seqlen: Optional[int] = None
if position_ids is not None:
batch_size, seq_length = input_ids.shape
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if not past_key_values:
lm_logits = self.layers(
input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
else:
hidden_layer = self.layers[0](input_ids)
for module in self.layers[1:-1]:
hidden_layer = module(
hidden_layer,
past_cache=past_key_values,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
lm_logits = self.layers[-1](hidden_layer)
loss = None
if labels is not None:
loss = self.loss(lm_logits, labels)
return CausalLMOutputWithPast(
loss=loss, logits=lm_logits, past_key_values=past_key_values
)

View File

@@ -1,66 +0,0 @@
"""
Flash attention monkey patch for cerebras btlm model
"""
import importlib
import logging
from typing import Optional, Tuple
import torch
from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
LOG = logging.getLogger("axolotl")
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)
modeling_btlm = importlib.import_module(module_name)
modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
flashattn_attn
)
def flashattn_attn(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
head_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
softmax_scale = (
1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
# Perform Flash attention
attn_output = flash_attn_func(
query,
key,
value,
dropout_p=0.0,  # no attention dropout is applied by this patch
softmax_scale=softmax_scale,  # 1/sqrt(head_dim) scaling computed above, or None
causal=not self.is_cross_attention,  # self-attention is causal, cross-attention is not
return_attn_probs=False,  # attention probabilities are not needed
)
# Optional: Apply head mask if it's not None
if head_mask is not None:
attn_output *= head_mask
attn_output = attn_output.permute(0, 2, 1, 3)
return attn_output, None # We don't have explicit attn_weights in Flash attention
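# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged shape note for the patch above: BTLM hands q/k/v to _attn as
# (batch, heads, seqlen, head_dim), while flash_attn_func expects
# (batch, seqlen, heads, head_dim), hence the permutes around the kernel call.
def _btlm_layout_demo():
    batch, heads, seqlen, head_dim = 2, 8, 16, 64
    query = torch.randn(batch, heads, seqlen, head_dim)
    assert query.permute(0, 2, 1, 3).shape == (batch, seqlen, heads, head_dim)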

View File

@@ -1,174 +0,0 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", "" # must be end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
yield "", system_prompt
else:
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message:
yield role + " ", message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt
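# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged example of what the patched generator yields: each item is a
# (role prefix, message + separator) pair, and an empty message produces an empty
# second element so masking can stop right after the role header. The template name
# below is only an assumption about what the installed fastchat registers.
def _get_turns_demo():
    from fastchat.conversation import get_conv_template

    conv = get_conv_template("vicuna_v1.1")  # ADD_COLON_TWO style
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], None)
    original_prompt = conv.get_prompt()  # fastchat's own monolithic rendering
    add_get_turns_to_conversation()
    turns = list(conv.get_turns())
    # joining the per-turn pieces should reproduce the monolithic prompt string
    assert "".join(role + msg for role, msg in turns) == original_prompt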

View File

@@ -13,18 +13,12 @@ import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OriginalLlamaDecoderLayer,
 )
-from transformers.models.llama.modeling_llama import (
-    LlamaMLP,
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
-from xformers.ops import SwiGLU
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
@@ -44,33 +38,7 @@ except ImportError:
 LOG = logging.getLogger("axolotl")
 
-def replace_llama_mlp_with_swiglu(model):
-    for name, module in model.named_modules():
-        if isinstance(module, LlamaMLP):
-            mlp = FusedMLP(
-                module.config, module.gate_proj, module.up_proj, module.down_proj
-            )
-            set_module_name(model, name, mlp)
-
-
-def replace_llama_qkv_with_fused(model):
-    for name, module in model.named_modules():
-        if isinstance(module, LlamaAttention):
-            qkv = FusedAttention(
-                module.config,
-                module.q_proj,
-                module.k_proj,
-                module.v_proj,
-                module.o_proj,
-            )
-            set_module_name(model, name, qkv)
-
-
-def replace_llama_attn_with_flash_attn(
-    packed: Optional[bool] = False,
-    cross_entropy: Optional[bool] = False,
-    rms_norm: Optional[bool] = False,
-):
+def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
@@ -81,124 +49,34 @@ def replace_llama_attn_with_flash_attn(
             llama_model_forward
         )
 
-    # skip only if explicitly disabled
-    if cross_entropy:
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-            LOG.info("patching with flash_attn.losses.cross_entropy")
-            transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-                CrossEntropyLoss, inplace_backward=True
-            )
-        except ImportError:
-            LOG.info(
-                "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
-            )
-
-    # skip only if explicitly disabled
-    if rms_norm:
-        try:
-            from flash_attn.ops.rms_norm import RMSNorm
-
-            class LlamaRMSNorm(RMSNorm):
-                """Patched LLamaRMSNorm"""
-
-                def __init__(self, hidden_size, eps=1e-6):
-                    super().__init__(hidden_size, eps=eps)
-
-            LOG.info("patching with flash_attn.ops.rms_norm")
-            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-        except ImportError:
-            LOG.info(
-                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-            )
-
-
-class FusedAttention(LlamaAttention):
-    """
-    Fused QKV Attention layer for incrementally improved training efficiency
-    """
-
-    def __init__(
-        self,
-        config,
-        q: torch.nn.Linear,  # pylint: disable=invalid-name
-        k: torch.nn.Linear,  # pylint: disable=invalid-name
-        v: torch.nn.Linear,  # pylint: disable=invalid-name
-        o: torch.nn.Linear,  # pylint: disable=invalid-name
-    ):
-        super().__init__(config)
-        self.config = config
-        self.init_device = next(iter(q.state_dict().values())).device
-
-        # define equivalent fused qkv projection
-        self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
-        self.qkv_proj = torch.nn.Linear(
-            q.in_features, sum(self.out_features), device=self.init_device, bias=False
-        )
-        self.o_proj = o
-
-        # overwrite initialized weights with pretrained weights
-        self.qkv_proj.weight.data = torch.cat(
-            (q.weight.data, k.weight.data, v.weight.data), dim=0
-        )
-
-    def _post_training(self, model, name):
-        q_proj, k_proj, v_proj = torch.split(
-            self.qkv_proj.weight.data, self.out_features, dim=0
-        )
-
-        new_attn = LlamaAttention(self.config)
-        new_attn.q_proj.weight.data = q_proj
-        new_attn.k_proj.weight.data = k_proj
-        new_attn.v_proj.weight.data = v_proj
-        new_attn.o_proj.weight.data = self.o_proj.weight.data
-
-        set_module_name(model, name, new_attn)
-
-
-class FusedMLP(torch.nn.Module):
-    """
-    Fused MLP layer for incrementally improved training efficiency
-    """
-
-    def __init__(
-        self,
-        config,
-        gate_proj: torch.nn.Linear,
-        up_proj: torch.nn.Linear,
-        down_proj: torch.nn.Linear,
-    ):
-        super().__init__()
-        self.config = config
-        self.swiglu = SwiGLU(
-            in_features=config.hidden_size,
-            hidden_features=config.intermediate_size,
-            bias=False,
-            _pack_weights=True,
-        )
-        # overwrite initialized weights with pretrained weights
-        self.swiglu.w12.weight.data = torch.cat(
-            (gate_proj.weight.data, up_proj.weight.data), dim=0
-        )
-        self.swiglu.w3.weight.data = down_proj.weight.data
-
-    def _post_training(self, model, name):
-        w1, w2 = torch.split(  # pylint: disable=invalid-name
-            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
-        )
-
-        # Assign the split weights back to the original layers
-        new_mlp = LlamaMLP(self.config)
-        new_mlp.gate_proj.weight.data = w1
-        new_mlp.up_proj.weight.data = w2
-        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
-
-        set_module_name(model, name, new_mlp)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
-        return self.swiglu(x)
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
+        )
+
+    try:
+        from flash_attn.ops.rms_norm import RMSNorm
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+        )
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
@@ -221,7 +99,6 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -261,14 +138,9 @@ def flashattn_forward(
         value_states = torch.cat(value_states, dim=-1)
     else:
-        if isinstance(self, FusedAttention):
-            query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
-                self.out_features, dim=-1
-            )
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(
         bsz, q_len, self.num_heads, self.head_dim
@@ -321,9 +193,7 @@ def flashattn_forward(
         # only on first autoregressive step q,k,v have same seqlen
         is_causal = key_states.shape == query_states.shape
 
-    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
-    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
+    if cu_seqlens is not None and max_seqlen is not None:
         # special handling using sample packing
         qkv = torch.stack(
             [query_states, key_states, value_states], dim=2
@@ -332,12 +202,7 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens,
-            max_seqlen,
-            dropout_p=dropout_rate,
-            softmax_scale=None,
-            causal=True,
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -360,7 +225,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            dropout_p=dropout_rate,
+            0.0,
             softmax_scale=None,
             causal=is_causal,
         )
@@ -373,7 +238,6 @@ def flashattn_forward(
             output = flash_attn_kvpacked_func(
                 query_states,
                 torch.stack([key_states, value_states], 2),
-                dropout_p=dropout_rate,
                 causal=is_causal,
             )
         else:
@@ -397,8 +261,6 @@ def flashattn_forward(
                 if attention_mask is not None
                 else None,
             )
-            if q_unpad.dtype != kv_unpad.dtype:
-                kv_unpad = kv_unpad.to(q_unpad.dtype)
             output_unpad = flash_attn_varlen_kvpacked_func(
                 q_unpad,
                 kv_unpad,
@@ -406,7 +268,7 @@
                 cu_seqlens_q,
                 cu_seqlens_k,
                 max_seqlen_q,
                 max_seqlen_k,
-                dropout_p=dropout_rate,
+                0.0,
                 softmax_scale=None,
                 causal=is_causal,
             )
@@ -612,13 +474,6 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
-        padding_mask = None
-    else:
-        if 0 in attention_mask:
-            padding_mask = attention_mask
-        else:
-            padding_mask = None
 
     attention_mask = (
         self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
             attention_mask,
@@ -653,9 +508,7 @@ def llama_model_forward(
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(
-                        *inputs,
-                    )
+                    return module(*inputs)
 
                 return custom_forward
@@ -664,10 +517,9 @@ def llama_model_forward(
                 hidden_states,
                 attention_mask,
                 position_ids,
-                past_key_value,
+                None,
                 output_attentions,
                 None,
-                padding_mask,
                 cu_seqlens,
                 max_seqlen,
             )
@@ -679,7 +531,6 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
-                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
@@ -726,7 +577,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
@@ -759,7 +609,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
-            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )

View File

@@ -25,8 +25,6 @@ def sdp_attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
-    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()

View File

@@ -29,8 +29,6 @@ def xformers_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
-    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()

View File

@@ -1,643 +0,0 @@
"""Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code
import logging
from typing import List, Optional, Tuple, Union
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
def replace_mistral_attn_with_flash_attn(
packed: Optional[bool] = False,
):
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward
)
if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer
)
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
mistral_model_forward
)
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
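# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged worked example of the banded mask above: with tgt_len=5 and
# sliding_window=3, query row i may attend to keys i-2..i, so the band is 0.0 and
# everything else becomes -inf via log(0).
def _sliding_window_mask_demo():
    mask = _make_sliding_window_causal_mask(
        bsz=1,
        tgt_len=5,
        dtype=torch.float32,
        device=torch.device("cpu"),
        past_key_values_length=0,
        sliding_window=3,
    )
    assert mask.shape == (1, 1, 5, 5)
    # the last query position sees keys 2, 3 and 4 only
    assert torch.isinf(mask[0, 0, 4, :2]).all()
    assert (mask[0, 0, 4, 2:] == 0).all()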
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
if attention_mask is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Unless attention_mask.shape[0] == 1, an error triggers after the eval loss pass, but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask
def flashattn_forward(
self: OriginalMistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
use_sliding_windows = (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None:
# Activate slicing cache only if the config has a `sliding_window` attribute
if (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
window_size=window_size,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal,
window_size=window_size,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
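# --- illustrative sketch (added for clarity, not part of the original file) ---
# Hedged example of the unpadding path above: two sequences of true lengths 3 and 5
# inside a padded batch of length 5 are flattened to 8 tokens total, and cu_seqlens
# records the [0, 3, 8] boundaries the varlen flash-attn kernels index with.
def _generate_qkv_demo():
    batch, seqlen, n_heads, head_dim = 2, 5, 2, 4
    q = torch.randn(batch, seqlen, n_heads, head_dim)
    k = torch.randn(batch, seqlen, n_heads, head_dim)
    v = torch.randn(batch, seqlen, n_heads, head_dim)
    mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
    qkv_unpad, cu_seqlens, max_seqlen, _, _ = generate_qkv(
        q, k, v, query_padding_mask=mask, key_padding_mask=mask, qkvpacked=True
    )
    assert qkv_unpad.shape == (8, 3, n_heads, head_dim)
    assert cu_seqlens.tolist() == [0, 3, 8]
    assert max_seqlen == 5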
def mistral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
LOG.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MistralDecoderLayer(OriginalMistralDecoderLayer):
"""
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*): cumulative sequence lengths used when sample packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
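For reference, the cu_seqlens/max_seqlen values threaded through this patched forward are derived from the packed position_ids (via get_cu_seqlens_from_pos_ids). Below is a toy, standalone illustration of the idea, not the library function; the position_ids values are made up:

import torch

position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2]])  # three packed samples in one row
starts = (position_ids[0] == 0).nonzero().flatten()          # a position id of 0 marks a new sample
seq_lens = torch.diff(starts, append=torch.tensor([position_ids.size(1)]))
cu_seqlens = torch.cat([torch.tensor([0]), seq_lens.cumsum(0)]).to(torch.int32)
max_seqlen = seq_lens.max()
print(cu_seqlens, max_seqlen)  # tensor([0, 4, 6, 9], dtype=torch.int32) tensor(4)

flash_attn_varlen_qkvpacked_func consumes exactly these boundaries, which is why the model forward computes them once and passes them down to every decoder layer.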

View File

@@ -1,65 +0,0 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel
def patch_neft(alpha, model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
embeddings.noisy_embedding_alpha = alpha
old_forward = embeddings.forward
# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
embeddings, embeddings.__class__
)
setattr(embeddings, "forward", bound_method)
embeddings._old_forward = old_forward # pylint: disable=protected-access
return model
def unpatch_neft(model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
if hasattr(embeddings, "_old_forward"):
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
del embeddings._old_forward # pylint: disable=protected-access
del embeddings.noisy_embedding_alpha
def neft_forward(self, inputs: torch.Tensor):
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
if self.training:
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-mag_norm, mag_norm
)
return embeddings
def pretrain_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
def post_train_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
unpatch_neft(trainer.model)
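As a standalone sketch of the noise scale used by neft_forward above (the alpha value and tensor sizes below are made up for illustration):

import math

import torch

alpha = 5.0                              # hypothetical noisy_embedding_alpha
embeddings = torch.zeros(1, 128, 4096)   # (batch, seq_len, hidden_dim)
mag_norm = alpha / math.sqrt(embeddings.size(1) * embeddings.size(2))
noisy = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
print(round(mag_norm, 5))                # ~0.0069: the noise bound shrinks as seq_len * dim grows

The noise is only injected while self.training is True, so inference is unaffected.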

View File

@@ -1,415 +0,0 @@
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
""" PyTorch StableLM Epoch model. """
import importlib
import math
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from accelerate import init_empty_weights
from einops import rearrange
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_varlen_qkvpacked_func,
)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
logger = logging.get_logger(__name__)
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_stablelm_epoch to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
)
modeling_stablelm = importlib.import_module(module_name)
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
flashattn_attn
)
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
stablelm_model_forward
)
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
decoder_layer_forward
)
def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
# pylint: disable=invalid-name
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
# pylint: disable=invalid-name
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def flashattn_attn(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
query_rot = query_states[..., : self.rotary_ndims]
query_pass = query_states[..., self.rotary_ndims :]
key_rot = key_states[..., : self.rotary_ndims]
key_pass = key_states[..., self.rotary_ndims :]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_rot, key_rot, cos, sin, position_ids
)
# [batch_size, num_heads, seq_len, head_dim]
query_states = torch.cat((query_states, query_pass), dim=-1)
key_states = torch.cat((key_states, key_pass), dim=-1)
if past_key_value is not None:
# Reuse k, v, self_attention
key_states = torch.cat((past_key_value[0], key_states), dim=2)
value_states = torch.cat((past_key_value[1], value_states), dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
softmax_scale = None
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
)
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# Upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
# Merge heads
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# Final linear projection
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def decoder_layer_forward(
self,
hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
]:
# pylint: disable=duplicate-code
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def stablelm_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
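The repeat_kv helper above expands grouped key/value heads so they line up with the query heads; a minimal standalone check of the shape math (the sizes are arbitrary):

import torch

batch, n_kv_heads, seqlen, head_dim, n_rep = 1, 2, 3, 4, 4
kv = torch.randn(batch, n_kv_heads, seqlen, head_dim)
expanded = kv[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seqlen, head_dim)
out = expanded.reshape(batch, n_kv_heads * n_rep, seqlen, head_dim)
assert torch.equal(out, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
print(out.shape)  # torch.Size([1, 8, 3, 4])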

View File

@@ -101,16 +101,3 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, value)

View File

@@ -1,7 +1,6 @@
"""Module to load prompt strategies.""" """Module to load prompt strategies."""
import importlib import importlib
import inspect
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
@@ -17,10 +16,6 @@ def load(strategy, tokenizer, cfg, ds_cfg):
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
return None
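The removed lines above forwarded ds_cfg only to loaders whose signature accepts it. A self-contained sketch of that dispatch (the loader names below are hypothetical):

import inspect

def loader_with_ds_cfg(tokenizer, cfg, ds_cfg=None):
    return ("with ds_cfg", ds_cfg)

def loader_without_ds_cfg(tokenizer, cfg):
    return ("without ds_cfg",)

def call_loader(func, tokenizer, cfg, ds_cfg):
    # pass ds_cfg only when the loader declares a ds_cfg parameter
    load_kwargs = {}
    if "ds_cfg" in inspect.signature(func).parameters:
        load_kwargs["ds_cfg"] = ds_cfg
    return func(tokenizer, cfg, **load_kwargs)

print(call_loader(loader_with_ds_cfg, None, None, {"field": "text"}))
print(call_loader(loader_without_ds_cfg, None, None, {"field": "text"}))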

View File

@@ -1,6 +1,6 @@
"""Module for Alpaca prompt strategy classes""" """Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional, Tuple from typing import Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg):
prompt_style = PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
prompt_style = ds_cfg["conversation"]
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(prompt_style), AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
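The removed ds_cfg handling let a dataset entry override the prompt style; a toy illustration of that lookup (the config values are made up):

ds_cfg = {"path": "some/dataset", "type": "alpaca_chat", "conversation": "chatml"}
prompt_style = "chat"  # PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
    prompt_style = ds_cfg["conversation"]
print(prompt_style)  # chatml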

View File

@@ -1,92 +0,0 @@
"""
Basic completion text
"""
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for Completion prompts.
"""
_field: str = "text"
def __init__(self, *args, max_length=None, **kwargs):
super().__init__(*args, **kwargs)
if max_length is not None:
self.max_length = max_length
@property
def supports_batched(self):
return True
@property
def field(self) -> str:
return self._field
@field.setter
def field(self, new_field: str):
self._field = new_field
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt[self.field],
"",
"",
)
def tokenize_prompt(self, prompt):
res = defaultdict(lambda: [])
feature_names = list(prompt.keys())
for row in zip(*prompt.values()):
prompt_row = dict(zip(feature_names, row))
(
instruction,
_,
_,
) = self.parse_instruction_fields(prompt_row)
full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt)
for key, val in tokenized_full_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
return dict(res)
def _build_full_prompt(
self, instruction, input, response
): # pylint: disable=redefined-builtin
return next(iter(self.prompter.build_prompt(instruction, input, response)))
class CompletionPrompter:
"""
Prompter for completion
"""
def build_prompt(
self,
instruction: str,
input=None, # pylint: disable=redefined-builtin, unused-argument
output=None, # pylint: disable=unused-argument
) -> Generator[str, None, None]:
yield instruction
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]
return strat
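The tokenize_prompt loop above splits long completions into sequence_len sized windows; a toy, self-contained version of that chunking (the token ids are made up):

sequence_len = 4
input_ids = list(range(10))
chunks = [input_ids[i : i + sequence_len] for i in range(0, len(input_ids), sequence_len)]
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]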

View File

@@ -24,15 +24,6 @@ def load(tokenizer, cfg):
)
def load_v2(tokenizer, cfg):
return ContextQaV2PromptTokenizingStrategy(
ContextV2Prompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class AlpacaContextPrompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
@@ -59,38 +50,6 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
)
class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenization Strategy to combine in-context article with a question and answer
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
"Context: "
+ prompt["context"]
+ "\nQuestion: "
+ prompt["question"]
+ "\n",
"",
"Answer: " + prompt["answer"],
)
class ContextV2Prompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
"""
system_prompt = ""
system_no_input_prompt = ""
def match_prompt_style(self):
# pylint: disable=duplicate-code
self.turn_format = "{instruction}\n{input}"
self.turn_no_input_format = "{instruction}"
self.system_format = "{system}"
class AlpacaMissingInfoContextPromptTokenizingStrategy(
InstructionPromptTokenizingStrategy
):
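For reference, the removed ContextQaV2PromptTokenizingStrategy assembles the instruction and response like this (the example values are made up):

prompt = {
    "context": "Axolotls are neotenic salamanders.",
    "question": "What are axolotls?",
    "answer": "Neotenic salamanders.",
}
instruction = "Context: " + prompt["context"] + "\nQuestion: " + prompt["question"] + "\n"
response = "Answer: " + prompt["answer"]
print(instruction + response)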

View File

@@ -1,11 +1,11 @@
"""Module for Jokes prompts using sharegpt style """ """Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2 from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg): def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy( return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(), ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -1,47 +1,21 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2 from axolotl.prompters import PromptStyle, ShareGPTPrompter
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>\n",
)
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg):
conversation = ( return SimpleShareGPTPromptTokenizingStrategy(
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None ShareGPTPrompter(PromptStyle.CHAT.value),
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -50,7 +24,7 @@ def load_role(tokenizer, cfg):
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -62,26 +36,8 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = True
@property
def strict(self):
return self._strict
@strict.setter
def strict(self, strict):
self._strict = strict
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
return prompt["conversations"]
if self.strict:
return conversations
# remap roles - allow for assistant turn
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
]
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
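The removed non-strict path above normalizes "assistant" turns to the "gpt" role before building the conversation; a toy run of that remap (the conversation content is made up):

role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
conversations = [
    {"from": "human", "value": "hi"},
    {"from": "assistant", "value": "hello!"},
]
turns = [{"from": role_map[t["from"]], "value": t["value"]} for t in conversations]
print(turns)  # [{'from': 'human', 'value': 'hi'}, {'from': 'gpt', 'value': 'hello!'}]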

View File

@@ -2,15 +2,12 @@
import abc
import copy
import functools
import logging
from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation
from transformers import PreTrainedTokenizer
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
@@ -21,8 +18,6 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
add_get_turns_to_conversation()
class InvalidDataException(Exception):
"""
@@ -45,47 +40,56 @@ class PromptTokenizingStrategy(abc.ABC):
self.prompter = prompter
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
# TODO: Document how they are different.
self.sequence_len = sequence_len
self.max_length = sequence_len
@abc.abstractmethod
def tokenize_prompt(self, prompt):
pass
@property @functools.lru_cache(maxsize=128)
def supports_batched(self): def _get_user_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False return False
def _tokenize( @functools.lru_cache(maxsize=128)
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False def _get_assistant_token(self):
) -> BatchEncoding: try:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if not prompt: if isinstance(id_or_ids, (int,)):
LOG.warning("Empty text requested for tokenization.") return id_or_ids
return empty except KeyError:
pass
return False
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer( result = self.tokenizer(
prompt, prompt,
truncation=True, truncation=True,
max_length=self.max_length, max_length=self.sequence_len,
padding=False, padding=False,
return_tensors=None, return_tensors=None,
) )
if len(result["input_ids"]) == 0: if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset") LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
return empty
if ( if (
result["input_ids"][-1] != self.tokenizer.eos_token_id len(result["input_ids"]) > 0
and len(result["input_ids"]) < self.max_length and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token and add_eos_token
): ):
result["input_ids"].append(self.tokenizer.eos_token_id) result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1) result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:] result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:] result["attention_mask"] = result["attention_mask"][1:]
@@ -121,7 +125,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
@@ -236,6 +240,23 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
)
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for Completion prompts.
"""
def tokenize_prompt(self, prompt):
full_prompt = self._build_full_prompt(prompt["text"], None, None)
tokenized_full_prompt = self._tokenize(full_prompt)
return tokenized_full_prompt
def _build_full_prompt(
self, instruction, input, response
): # pylint: disable=redefined-builtin
return next(iter(self.prompter.build_prompt(instruction, input, response)))
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
""" """
Tokenizing strategy for Reflection prompts. Tokenizing strategy for Reflection prompts.
@@ -245,7 +266,6 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
raise NotImplementedError raise NotImplementedError
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
( (
instruction, instruction,
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
@@ -270,7 +290,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
user_prompt_len = len(tokenized_user_prompt["input_ids"]) user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [ tokenized_full_prompt["labels"] = [
IGNORE_INDEX -100
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
return tokenized_full_prompt return tokenized_full_prompt
@@ -334,89 +354,52 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
return prompt["conversations"] return prompt["conversations"]
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default() result, current_len = tokenize_prompt_default()
conversation: Conversation = ( user_token = self._get_user_token()
self.prompter._conversation.copy() # pylint: disable=protected-access assistant_token = self._get_assistant_token()
)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]
try: try:
for _, part in enumerate( for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt)) self.prompter.build_prompt(self.get_conversation_thread(prompt))
): ):
if not isinstance(part, tuple): if isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}") if part[0] == "USER:":
continue part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
user, assistant = conversation.roles res = self._tokenize(
role, content = part part.strip(),
add_eos_token=False,
# Uses "in" because role contains extra characters strip_bos_token=True,
if user in role: )
role = ( if user_token:
role.replace(role_remap[0]["from"], role_remap[0]["to"]) res["input_ids"] = [user_token, *res["input_ids"]]
if role_remap # everything from this is masked out from the labels
else role labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
) elif part[0] == "ASSISTANT:":
turn = role + content # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
# this is still the user query, we should part = part[0] + part[1] if not assistant_token else part[1]
if not content.strip(): # this should be the assistent response, should end with an eos token
LOG.warning(f"user turn has empty text: {prompt}") res = self._tokenize(
res = self._tokenize( part.strip(),
turn, add_eos_token=True,
add_eos_token=False, strip_bos_token=True,
strip_bos_token=True, )
) if assistant_token:
# everything from this is masked out from the labels res["input_ids"] = [
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) assistant_token,
elif assistant in role: *res["input_ids"],
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID ]
role = ( # not masked out from labels
role.replace(role_remap[1]["from"], role_remap[1]["to"]) labels = copy.deepcopy(res["input_ids"])
if role_remap elif part[0] == "SYSTEM:":
else role part = part[1] # Ignore the system role from preamble
) # this is only ever the first part, should include the bos token and the user query
turn = role + content res = self._tokenize(
# this should be the assistant response, should end with an eos token part.strip(), add_eos_token=False, strip_bos_token=False
if not content.strip(): )
LOG.warning(f"assistant turn has empty text: {prompt}") # everything from this is masked out from the labels
res = self._tokenize( labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
turn, else:
add_eos_token=True, LOG.warning(f"unhandled role: {part[0]}")
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
elif role == "":
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result( result, current_len = parse_tokenized_to_result(
@@ -430,6 +413,29 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
except (KeyError, AssertionError, IndexError) as err: except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
"""

View File

@@ -1,15 +1,12 @@
"""Module containing prompters""" """Module containing prompters"""
import dataclasses
import logging import logging
from enum import Enum from enum import Enum, auto
from typing import Generator, Optional, Union from typing import Generator, List, Optional, Tuple, Union
from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
class PromptStyle(Enum): class PromptStyle(Enum):
@@ -22,13 +19,7 @@ class PromptStyle(Enum):
CHATML = "chatml" CHATML = "chatml"
class Prompter: class AlpacaPrompter:
"""
Base prompter class for all prompters
"""
class AlpacaPrompter(Prompter):
""" """
Base class for alpaca prompters Base class for alpaca prompters
""" """
@@ -63,38 +54,29 @@ class AlpacaPrompter(Prompter):
) )
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input_text:
res = (
self.system_format.format(system=self.system_prompt)
if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input_text)
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_no_input_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
return res
def build_prompt( def build_prompt(
self, self,
instruction: str, instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None, output: Union[None, str] = None,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
yield self._build_result(instruction, input, output) # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
def __repr__(self) -> str: if input:
return REPR_TEMPLATE.format( res = (
full_prompt=self._build_result("{instruction}", "{input}", "{output}") self.system_format.format(system=self.system_prompt)
) if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input)
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
class UnpromptedPrompter(AlpacaPrompter): class UnpromptedPrompter(AlpacaPrompter):
@@ -153,6 +135,20 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter:
"""
Prompter for completion
"""
def build_prompt(
self,
instruction: str,
input=None, # pylint: disable=redefined-builtin, unused-argument
output=None, # pylint: disable=unused-argument
) -> Generator[str, None, None]:
yield instruction
class GPTeacherPrompter(AlpacaPrompter): class GPTeacherPrompter(AlpacaPrompter):
""" """
Prompter for GPTeacher Prompter for GPTeacher
@@ -165,7 +161,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
""" """
class ReflectAlpacaPrompter(Prompter): class ReflectAlpacaPrompter:
""" """
Prompter for ReflectAlpaca Prompter for ReflectAlpaca
""" """
@@ -208,14 +204,14 @@ class ReflectAlpacaPrompter(Prompter):
) )
self.response_split = "ASSISTANT:" self.response_split = "ASSISTANT:"
def _build_result( def build_prompt(
self, self,
instruction: str, instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None, output: Union[None, str] = None,
reflection: Union[None, str] = None, reflection: Union[None, str] = None,
corrected: Union[None, str] = None, corrected: Union[None, str] = None,
): ) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input: if input:
@@ -229,30 +225,54 @@ class ReflectAlpacaPrompter(Prompter):
corrected=corrected, corrected=corrected,
) )
res = f"{res}{label}" res = f"{res}{label}"
yield res
return res
def build_prompt( class SeparatorStyle(Enum):
self, """Different separator style."""
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin SINGLE = auto()
output: Union[None, str] = None, TWO = auto()
reflection: Union[None, str] = None, DOLLY = auto()
corrected: Union[None, str] = None,
) -> Generator[str, None, None]:
# pylint: disable=duplicate-code # TODO clean this 💩 up
yield self._build_result( @dataclasses.dataclass
instruction, class Conversation:
input, """A class that keeps all conversation history."""
output,
reflection, system: str
corrected, roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: Optional[str] = None
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
else:
LOG.warning(f"role with empty message: {role}")
yield (role + ":", "")
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
) )
def __repr__(self) -> str: def append_message(self, role, message):
return REPR_TEMPLATE.format( self.messages.append([role, message])
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)
SHAREGPT_ASSERTION_FAILED_ROLE = ( SHAREGPT_ASSERTION_FAILED_ROLE = (
@@ -260,34 +280,35 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
) )
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods class ShareGPTPrompter: # pylint: disable=too-few-public-methods
""" """
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
""" """
role_key_human = "human" def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
role_key_model = "gpt" if prompt_style != PromptStyle.CHAT.value:
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
system: str = (
system_prompt
if system_prompt
else (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
)
self._conversation = Conversation(
system=system,
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
)
def __init__( def build_prompt(self, source) -> Generator[str, None, None]:
self,
prompt_style=None, # pylint: disable=unused-argument
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
):
if conversation:
if isinstance(conversation, Conversation):
self._conversation = conversation
else:
self._conversation = get_conv_template(conversation)
else:
self._conversation = get_conv_template("vicuna_v1.1")
if role_key_human:
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
def _build_result(self, source):
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations
@@ -299,14 +320,17 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
# Add the conversation system prompt if provided, otherwise use the default one # Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system": if source[0]["from"] == "system":
conv.set_system_message(source[0]["value"]) conv.system = source[0]["value"]
source.pop(0) source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]} roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
try: try:
# Apply prompt templates # Apply prompt templates
if source[0]["from"] not in roles: if (
source[0]["from"] not in roles
or roles[source[0]["from"]] != conv.roles[0]
):
# Skip the first one if it is not from human # Skip the first one if it is not from human
source = source[1:] source = source[1:]
except IndexError as err: except IndexError as err:
@@ -314,54 +338,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
raise err raise err
conv.messages = [] conv.messages = []
for _, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
if len(conv.messages) > 0 and ( assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
return conv.get_turns() for part in conv.get_prompt():
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
for part in turns:
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
yield part yield part
def __repr__(self) -> str:
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
class ShareGPTPrompterV2(ShareGPTPrompter):
"""
A V2 prompter that generates prompts for the ShareGPT
"""
def __init__(
self,
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
)
class UnsupportedPrompter(Prompter):
"""
A dummy class for custom prompters
"""
def __init__(self) -> None:
pass
def __repr__(self):
return "Pre-tokenized or custom dataset types are unsupported for logging"

View File

@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os import os
import signal import signal
import sys import sys
@@ -8,15 +9,13 @@ from pathlib import Path
from typing import Optional from typing import Optional
import torch import torch
import transformers.modelcard
from accelerate.logging import get_logger # add src to the pythonpath so we don't need to pip install this
from datasets import Dataset from datasets import Dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -26,7 +25,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging() configure_logging()
LOG = get_logger("axolotl.train") LOG = logging.getLogger("axolotl.train")
@dataclass @dataclass
@@ -41,13 +40,13 @@ class TrainDatasetMeta:
def train( def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
): ):
# load the tokenizer first # load the tokenizer first
LOG.debug( LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset
@@ -55,10 +54,7 @@ def train(
total_num_steps = dataset_meta.total_num_steps total_num_steps = dataset_meta.total_num_steps
# Load the model and tokenizer # Load the model and tokenizer
msg = "loading model" LOG.info("loading model and (optionally) peft_config...")
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
@@ -84,15 +80,14 @@ def train(
model.config.use_cache = False model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect # go ahead and presave, so we have the adapter config available to inspect
if peft_config: if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir) peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
@@ -107,14 +102,13 @@ def train(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
) )
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.group_by_length: if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length") LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer) if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -122,19 +116,9 @@ def train(
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else: else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload() model = model.merge_and_unload()
@@ -146,49 +130,10 @@ def train(
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp: if cfg.fsdp:
trainer.save_model(cfg.output_dir) trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
return model, tokenizer return model, tokenizer
def pretrain_hooks(cfg, trainer):
"""
Run hooks right before kicking off the training
:param cfg:
:param trainer:
:return:
"""
neft_embeddings.pretrain_hook(cfg, trainer)
def post_train_hooks(cfg, trainer):
"""
Run hooks right after training completes
:param cfg:
:param trainer:
:return:
"""
neft_embeddings.post_train_hook(cfg, trainer)
View File
@@ -1,44 +1,13 @@
"""Benchmarking and measurement utilities""" """Benchmarking and measurement utilities"""
import functools
import pynvml import pynvml
import torch import torch
from pynvml.nvml import NVMLError
def check_cuda_device(default_value):
"""
wraps a function and returns the default value instead of running the
wrapped function if cuda isn't available or the device is auto
:param default_value:
:return:
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
device = kwargs.get("device", args[0] if args else None)
if (
not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):
return default_value
return func(*args, **kwargs)
return wrapper
return deco
@check_cuda_device(0.0)
def gpu_memory_usage(device=0): def gpu_memory_usage(device=0):
return torch.cuda.memory_allocated(device) / 1024.0**3 return torch.cuda.memory_allocated(device) / 1024.0**3
@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0): def gpu_memory_usage_all(device=0):
usage = torch.cuda.memory_allocated(device) / 1024.0**3 usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3 reserved = torch.cuda.memory_reserved(device) / 1024.0**3
@@ -46,22 +15,22 @@ def gpu_memory_usage_all(device=0):
return usage, reserved - usage, max(0, smi - reserved) return usage, reserved - usage, max(0, smi - reserved)
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0): def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device): if isinstance(device, torch.device):
device = device.index device = device.index
if isinstance(device, str) and device.startswith("cuda:"): if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:]) device = int(device[5:])
try:
pynvml.nvmlInit() pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device) handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle) info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3 return info.used / 1024.0**3
except NVMLError:
return 0.0
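A quick sketch (not part of the diff) of how the check_cuda_device guard above behaves once applied; the device strings below are illustrative:
gpu_memory_usage("cuda:0")    # runs the wrapped function, returns allocated GiB
gpu_memory_usage("cpu")       # short-circuits to the default, 0.0
gpu_memory_usage_all("auto")  # short-circuits to (0.0, 0.0, 0.0) without touching torch.cuda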
def log_gpu_memory_usage(log, msg, device): def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device) usage, cache, misc = gpu_memory_usage_all(device)
extras = [] extras = []
if cache > 0: if cache > 0:
View File
@@ -11,13 +11,10 @@ import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import wandb
from datasets import load_dataset from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm from tqdm import tqdm
from transformers import ( from transformers import (
GenerationConfig,
Trainer,
TrainerCallback, TrainerCallback,
TrainerControl, TrainerControl,
TrainerState, TrainerState,
@@ -28,7 +25,6 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
barrier, barrier,
broadcast_dict,
gather_scalar_from_all_ranks, gather_scalar_from_all_ranks,
get_world_size, get_world_size,
is_distributed, is_distributed,
@@ -37,32 +33,32 @@ from axolotl.utils.distributed import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks") LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100 IGNORE_INDEX = -100
class EvalFirstStepCallback( class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
TrainerCallback """Callback to save the PEFT adapter"""
): # pylint: disable=too-few-public-methods disable=unused-argument
"""
Callback to trigger evals on the first step
"""
def on_step_end( def on_save(
self, self,
args: TrainingArguments, args: TrainingArguments,
state: TrainerState, state: TrainerState,
control: TrainerControl, control: TrainerControl,
**kwargs, **kwargs,
): ):
if ( checkpoint_folder = os.path.join(
args.evaluation_strategy == IntervalStrategy.STEPS args.output_dir,
and args.eval_steps < 1.0 f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
and state.global_step == 1 )
):
control.should_evaluate = True peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
kwargs["model"].save_pretrained(
peft_model_path, save_safetensors=args.save_safetensors
)
return control return control
@@ -275,7 +271,6 @@ def bench_eval_callback_factory(trainer, tokenizer):
lambda: len(data_loader), get_world_size() lambda: len(data_loader), get_world_size()
) )
results = {}
if is_distributed() and not is_main_process(): if is_distributed() and not is_main_process():
dist.gather_object(local_bench_names, dst=0) dist.gather_object(local_bench_names, dst=0)
else: else:
@@ -321,220 +316,4 @@ def bench_eval_callback_factory(trainer, tokenizer):
)["accuracy"] )["accuracy"]
trainer.log(results) trainer.log(results)
results = broadcast_dict(results)
for key, val in results.items():
metrics[key] = val
return BenchEvalCallback return BenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
eval_table_size = self.cfg.eval_table_size
if eval_table_size <= 0:
return control
trainer.model.eval()
device = torch.device(self.cfg.device)
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
def logits_to_tokens(logits) -> torch.Tensor:
probabilities = torch.softmax(logits, dim=-1)
# Get the predicted token ids (the ones with the highest probability)
predicted_token_ids = torch.argmax(probabilities, dim=-1)
return predicted_token_ids
def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges
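For clarity, a worked example of the helper above (the position_ids list is hypothetical): packed samples are delimited by each reset of position_id back to 0.
find_ranges([0, 1, 2, 0, 1, 0, 1, 2, 3])  # [(0, 2), (3, 4), (5, 8)] — three packed sequences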
def log_table_from_dataloader(name: str, table_dataloader):
table = wandb.Table( # type: ignore[attr-defined]
columns=[
"id",
"Prompt",
"Correct Completion",
"Predicted Completion (model.generate)",
"Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0
for batch in tqdm(table_dataloader):
if row_index > eval_table_size:
break
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
(_, batch_logits, _) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)
prompt_token_ids_list = []
pred_step_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids, logits in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
batch_logits,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
pred_step_token_ids = logits_to_tokens(
logits[start : end + 1]
)[tokens_with_loss]
pred_step_token_ids_list.append(pred_step_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
pred_step_texts = tokenizer.batch_decode(
pred_step_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
for (
prompt_text,
completion_text,
prediction_text,
pred_step_text,
) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts
):
table.add_data(
row_index,
prompt_text,
completion_text,
prediction_text,
pred_step_text,
)
row_index += 1
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
return control
return LogPredictionCallback
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
"""Callback to save axolotl config to wandb"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
artifact = wandb.Artifact(name="axolotl-config", type="config")
artifact.add_file(local_path=self.axolotl_config_path)
wandb.run.log_artifact(artifact)
LOG.info("Axolotl config has been saved to WandB as an artifact.")
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
View File
@@ -119,30 +119,3 @@ class DataCollatorForSeq2Seq:
features["decoder_input_ids"] = decoder_input_ids features["decoder_input_ids"] = decoder_input_ids
return features return features
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack, specific to using the BatchSampler
"""
def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for item in features
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
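In short, the BatchSamplerDataCollatorForSeq2Seq above flattens a list of already-packed samples into a single long example before handing it to the base collator for padding. A hypothetical input/output sketch (values made up):
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5],    "attention_mask": [1, 1],    "labels": [4, 5]},
]
# merged into one feature dict before DataCollatorForSeq2Seq padding:
# {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1], "labels": [1, 2, 3, 4, 5]}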
View File
@@ -4,7 +4,6 @@ import logging
import os import os
import torch import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.models import load_model_config from axolotl.utils.models import load_model_config
@@ -26,11 +25,9 @@ def choose_device(cfg):
return "cpu" return "cpu"
cfg.device = get_device() cfg.device = get_device()
if cfg.world_size == 1: if cfg.device_map != "auto":
cfg.device_map = "auto"
else:
if cfg.device.startswith("cuda"): if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()} cfg.device_map = {"": cfg.local_rank}
else: else:
cfg.device_map = {"": cfg.device} cfg.device_map = {"": cfg.device}
@@ -49,12 +46,8 @@ def normalize_config(cfg):
cfg.batch_size = ( cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
) )
if cfg.eval_batch_size is None:
cfg.eval_batch_size = cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
choose_device(cfg) choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp: if cfg.ddp:
@@ -77,66 +70,20 @@ def normalize_config(cfg):
else: else:
cfg.torch_dtype = torch.float32 cfg.torch_dtype = torch.float32
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
# figure out if the model is llama # figure out if the model is llama
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama") (hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower() or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower()) or (cfg.model_type and "llama" in cfg.model_type.lower())
) )
# figure out if the model is falcon
cfg.is_falcon_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
)
cfg.is_mistral_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"mistral",
]
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower()
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
def validate_config(cfg): def validate_config(cfg):
if is_torch_bf16_gpu_available():
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if cfg.max_packed_sequence_len and cfg.sample_packing: if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError( raise ValueError(
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing" "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
@@ -150,11 +97,6 @@ def validate_config(cfg):
) )
) )
if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if cfg.gradient_accumulation_steps and cfg.batch_size: if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError( raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size" "please set only one of gradient_accumulation_steps or batch_size"
@@ -165,11 +107,6 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.", "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
) )
if cfg.eval_batch_size != cfg.micro_batch_size:
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
if cfg.load_4bit: if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated") raise ValueError("cfg.load_4bit parameter has been deprecated")
@@ -195,15 +132,9 @@ def validate_config(cfg):
if not cfg.load_in_4bit: if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora") raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"): if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -217,9 +148,6 @@ def validate_config(cfg):
if cfg.lr_scheduler == "one_cycle": if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler") raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
if cfg.trust_remote_code: if cfg.trust_remote_code:
LOG.warning( LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
@@ -258,10 +186,6 @@ def validate_config(cfg):
LOG.warning( LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely." "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
) )
if cfg.pretraining_dataset and not cfg.max_steps:
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and ( if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer not cfg.optimizer or "adamw" not in cfg.optimizer
@@ -291,87 +215,6 @@ def validate_config(cfg):
"sample_packing not compatible with xformers_attention. Use flash_attention" "sample_packing not compatible with xformers_attention. Use flash_attention"
) )
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if cfg.save_steps % cfg.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
if cfg.model_type == "MixFormerSequentialForCausalLM" and cfg.adapter is not None:
LOG.warning("Use AutoModelForCausalLM for phi/MixFormer models with qLoRA")
if cfg.model_config_type == "mixformer-sequential":
if cfg.sample_packing:
if cfg.adapter is not None:
LOG.warning(
"phi/MixFormer models are not currently compatible with LoRA and sample_packing"
)
if cfg.model_type == "AutoModelForCausalLM":
raise ValueError(
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
)
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if not ds_cfg.type:
continue
if ds_cfg.type == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = "sharegpt"
if "sharegpt_simple" in ds_cfg.type:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if (
cfg.evaluation_strategy
and cfg.eval_steps
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if (
cfg.sample_packing
and cfg.eval_table_size
and cfg.eval_sample_packing is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25
View File
@@ -2,8 +2,9 @@
import functools import functools
import hashlib import hashlib
import logging import logging
from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Union from typing import Tuple, Union
import torch import torch
from datasets import ( from datasets import (
@@ -16,28 +17,29 @@ from datasets import (
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy, AlpacaReflectionPTStrategy,
CompletionPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy, GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy, SummarizeTLDRPromptTokenizingStrategy,
) )
from axolotl.prompters import ( from axolotl.prompters import (
AlpacaPrompter, AlpacaPrompter,
CompletionPrompter,
GPTeacherPrompter, GPTeacherPrompter,
JeopardyPrompter, JeopardyPrompter,
MultipleChoiceConcisePrompter, MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter, MultipleChoiceExplainPrompter,
Prompter,
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
ShareGPTPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
UnsupportedPrompter,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
@@ -47,20 +49,13 @@ from axolotl.utils.trainer import (
) )
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
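A tiny usage note (the input string below is invented): the helper only fingerprints dataset configurations, and the digest becomes the prepared-dataset directory or hub repo name.
md5("2048@alpaca_data.json:alpaca:None|LlamaTokenizer")  # 32-character hex string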
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset, prompters = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
else: else:
@@ -73,43 +68,37 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps, prompters
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer cfg, train_dataset, eval_dataset
) )
if cfg.max_steps: if cfg.max_steps:
total_num_steps = min( total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
) )
LOG.info(f"Maximum number of steps set at {total_num_steps}") LOG.info(f"Maximum number of steps set at {total_num_steps}")
else: else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset) total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
return train_dataset, eval_dataset, total_num_steps, prompters return train_dataset, eval_dataset, total_num_steps
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) -> Tuple[DatasetDict, List[Prompter]]: ) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
md5( md5( # nosec
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
+ "|".join( + "|".join(
sorted( sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg.datasets
]
)
) )
+ "|" + "|"
+ tokenizer_name + tokenizer_name
) ).encode("utf-8")
) ).hexdigest()
) )
prepared_ds_path = ( prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash Path(cfg.dataset_prepared_path) / ds_hash
@@ -117,13 +106,12 @@ def load_tokenized_prepared_datasets(
else Path(default_dataset_prepared_path) / ds_hash else Path(default_dataset_prepared_path) / ds_hash
) )
dataset = None dataset = None
prompters = []
use_auth_token = cfg.hf_use_auth_token use_auth_token = cfg.hf_use_auth_token
try: try:
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
dataset = load_dataset( dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token, use_auth_token=use_auth_token,
) )
dataset = dataset["train"] dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
@@ -131,7 +119,7 @@ def load_tokenized_prepared_datasets(
if dataset: if dataset:
... ...
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")): elif any(prepared_ds_path.glob("*")):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path)) dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...") LOG.info("Prepared dataset loaded from disk...")
@@ -156,92 +144,44 @@ def load_tokenized_prepared_datasets(
yield dataset yield dataset
# pylint: disable=invalid-name # pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg.datasets): for d in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False ds_from_hub = False
try: try:
load_dataset( load_dataset(
config_dataset.path, d.path,
name=config_dataset.name, name=d.name,
streaming=True, streaming=True,
token=use_auth_token, use_auth_token=use_auth_token,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError): except FileNotFoundError:
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(config_dataset.path) local_path = Path(d.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
ds = load_from_disk(config_dataset.path) # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset(
d.path,
name=d.name,
data_files=d.data_files,
streaming=False,
split=None,
)
elif local_path.is_file(): elif local_path.is_file():
ds_type = get_ds_type(config_dataset) ds_type = "json"
if d.ds_type:
ds_type = d.ds_type
elif ".parquet" in d.path:
ds_type = "parquet"
elif ".arrow" in d.path:
ds_type = "arrow"
ds = load_dataset( ds = load_dataset(
ds_type, ds_type,
name=config_dataset.name, name=d.name,
data_files=config_dataset.path, data_files=d.path,
streaming=False, streaming=False,
split=None, split=None,
) )
@@ -251,99 +191,152 @@ def load_tokenized_prepared_datasets(
) )
elif ds_from_hub: elif ds_from_hub:
ds = load_dataset( ds = load_dataset(
config_dataset.path, d.path,
name=config_dataset.name, name=d.name,
streaming=False, streaming=False,
data_files=config_dataset.data_files, data_files=d.data_files,
token=use_auth_token, use_auth_token=use_auth_token,
) )
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
)
else: else:
if isinstance(config_dataset.data_files, str): fp = hf_hub_download(
fp = hf_hub_download( repo_id=d.path,
repo_id=config_dataset.path, repo_type="dataset",
repo_type="dataset", filename=d.data_files,
filename=config_dataset.data_files, )
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
)
)
else:
raise ValueError(
"data_files must be either a string or list of strings"
)
ds = load_dataset( ds = load_dataset(
"json", "json", name=d.name, data_files=fp, streaming=False, split=None
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
) )
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if config_dataset.shards: if d.shards:
if "train" in ds: if "train" in ds:
ds = ds.shuffle(seed=seed)["train"].shard( ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=config_dataset.shards, index=0 num_shards=d.shards, index=0
) )
else: else:
ds = ds.shuffle(seed=seed).shard( ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
num_shards=config_dataset.shards, index=0
)
d_base_type = d_prompt_style = None d_base_type = d_prompt_style = None
d_type = config_dataset.type d_type = d.type
if isinstance(d_type, str): if isinstance(d_type, str):
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds: if "train" in ds:
ds = ds["train"] ds = ds["train"]
elif ( if (
isinstance(ds, DatasetDict) "input_ids" in ds.features
and config_dataset.train_on_split and "attention_mask" in ds.features
and config_dataset.train_on_split in ds and "labels" in ds.features
): ):
ds = ds[config_dataset.train_on_split] # dataset is already tokenized, just drop it straight in
elif isinstance(ds, DatasetDict): datasets.append(ds)
raise ValueError( elif isinstance(d.type, DictDefault):
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `" ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif ds_strategy := load(d.type, tokenizer, cfg, d):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "explainchoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceExplainPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "concisechoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceConcisePrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "summarizetldr":
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
SummarizeTLDRPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy(
JeopardyPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "oasst":
ds_strategy = OpenAssistantPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(
GPTeacherPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "reflection":
ds_strategy = AlpacaReflectionPTStrategy(
ReflectAlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "completion":
ds_strategy = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
LOG.info("merging datasets") LOG.info("merging datasets")
dataset = concatenate_datasets(datasets) dataset = concatenate_datasets(datasets)
@@ -361,32 +354,14 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
return dataset, prompters return dataset
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_prepare_datasets( def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
cfg, cfg,
default_dataset_prepared_path, default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Prompter]]: ) -> Tuple[Dataset, Dataset]:
max_packed_sequence_len = ( max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
) )
@@ -395,12 +370,11 @@ def load_prepare_datasets(
) # make sure we don't accidentally set it larger than sequence_len ) # make sure we don't accidentally set it larger than sequence_len
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
prompters: List[Prompter] = []
if cfg.max_packed_sequence_len is not None: if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset # see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else "" seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str( ds_hash = str(
md5( md5( # nosec
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
@@ -411,8 +385,8 @@ def load_prepare_datasets(
) )
+ "|" + "|"
+ tokenizer_name + tokenizer_name
) ).encode("utf-8")
) ).hexdigest()
) )
prepared_ds_path = ( prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash Path(cfg.dataset_prepared_path) / ds_hash
@@ -429,7 +403,7 @@ def load_prepare_datasets(
) )
dataset = load_dataset( dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token, use_auth_token=use_auth_token,
) )
dataset = dataset["train"] dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
@@ -437,7 +411,7 @@ def load_prepare_datasets(
if dataset: if dataset:
... ...
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")): elif any(prepared_ds_path.glob("*")):
LOG.info( LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..." f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
) )
@@ -451,7 +425,7 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
else: else:
dataset, prompters = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) )
@@ -493,7 +467,7 @@ def load_prepare_datasets(
private=True, private=True,
) )
else: else:
dataset, prompters = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) )
@@ -526,16 +500,21 @@ def load_prepare_datasets(
+ "|" + "|"
+ str(cfg.seed or 42) + str(cfg.seed or 42)
) )
train_fingerprint = md5(to_hash_train) train_fingerprint = hashlib.md5(
test_fingerprint = md5(to_hash_test) to_hash_train.encode(), usedforsecurity=False
).hexdigest()
test_fingerprint = hashlib.md5(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
dataset = dataset.train_test_split( with zero_first(is_main_process()):
test_size=cfg.val_set_size, dataset = dataset.train_test_split(
shuffle=False, test_size=cfg.val_set_size,
seed=cfg.seed or 42, shuffle=False,
train_new_fingerprint=train_fingerprint, seed=cfg.seed or 42,
test_new_fingerprint=test_fingerprint, train_new_fingerprint=train_fingerprint,
) test_new_fingerprint=test_fingerprint,
)
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] eval_dataset = dataset["test"]
@@ -543,151 +522,12 @@ def load_prepare_datasets(
train_dataset = dataset train_dataset = dataset
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, prompters return train_dataset, eval_dataset
def get_dataset_wrapper( def encode_pretraining(tokenizer, max_tokens, examples):
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
dataset_wrapper = None
dataset_prompter = None
if (
"input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
# dataset is already tokenized, just drop it straight in
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = dataset
elif isinstance(config_dataset.type, DictDefault):
ds_strategy = load(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "explainchoice":
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "concisechoice":
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "summarizetldr":
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "jeopardy":
dataset_prompter = JeopardyPrompter(d_prompt_style)
ds_strategy = JeopardyPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "oasst":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = OpenAssistantPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "gpteacher":
dataset_prompter = GPTeacherPrompter(d_prompt_style)
ds_strategy = GPTeacherPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "reflection":
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaReflectionPTStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
else:
suffix = ""
if ":load_" in config_dataset.type:
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
LOG.error(
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
)
raise ValueError(
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
)
return dataset_wrapper, dataset_prompter
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples, examples["text"],
truncation=True, truncation=True,
max_length=max_tokens - 2, max_length=max_tokens - 2,
add_special_tokens=True, add_special_tokens=True,
@@ -795,12 +635,6 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train") dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map( # TODO dynamically figure out which columns/features to remove
encode, dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
batched=True,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
)
return dataset return dataset
View File
@@ -0,0 +1,300 @@
# pylint: skip-file
import hashlib
import itertools
import logging
import math
from typing import Any, Callable, List, Union
import numba
import numpy as np
from torch.utils.data import DistributedSampler, Sampler
LOG = logging.getLogger("axolotl.utils.dataloader")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
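A small worked example of the check above (lengths are made up): four sequences of lengths 7, 5, 4 and 2 against two bins of capacity 9.
import numpy as np
lengths = np.array([5, 7, 2, 4])
ffd_check(lengths, 9, 2)  # True:  bin 1 holds 7+2, bin 2 holds 5+4
ffd_check(lengths, 9, 1)  # False: 18 tokens cannot fit a single bin of capacity 9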
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result, len(a)
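The same invented lengths through the result-returning variant (start_index=0), showing which original indices land in which bin:
import numpy as np
bins_result, n_packed = ffd_with_result(np.array([5, 7, 2, 4]), 9, 0)
# bins_result -> [[1, 2], [0, 3]]  (sample 1 len 7 + sample 2 len 2; sample 0 len 5 + sample 3 len 4)
# n_packed    -> 4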
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
"""
:param lengths: array of lengths of each sample
:param lengths_cumsum: cumulative sum of consecutive lengths
:param rank: rank for this process
:param c: length of tokens per batch
:param n: number of ranks
:return:
"""
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length left
batch, tot_seqs = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
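A hypothetical end-to-end call of allocate (two ranks, capacity 9, made-up lengths), illustrating that trailing samples which can no longer fill one bin per rank are simply left over for this pass:
import numpy as np
lengths = np.array([7, 5, 4, 2, 6, 3])
batches, totseqs, used, slots = allocate(lengths, np.cumsum(lengths), 0, 9, 2)
# batches -> [[0, 3]] for rank 0 (rank 1 would get [[1, 2]]); totseqs -> [4]
# used -> 18 tokens packed into slots -> 18; samples 4 and 5 are dropped this round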
def chunk(iterable, n):
"""
Chunk data into tuples of length n
"""
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(itertools.islice(it, n)):
yield batch
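Usage of the helper above is straightforward; an invented call:
list(chunk(range(7), 3))  # [(0, 1, 2), (3, 4, 5), (6,)]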
def hash_indices(lst: List[int]) -> str:
# Convert the list of integers to a string representation
concatenated = ",".join(map(str, lst))
# Generate the hash
sha256 = hashlib.sha256()
sha256.update(concatenated.encode())
return sha256.hexdigest()
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
Approximates (to within at most ~1.22x of) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
"""
def __init__(
self,
dataset: Any,
collate_fn: Callable,
seq_max_length: int = 2048,
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
):
# Dataset
self.dataset = dataset
self.lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
self.sampler = sampler
self.batch_size = batch_size
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
self.num_replicas = 1
self.rank = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
indices = [idx for idx in self.sampler]
else:
indices = range(0, len(self.dataset))
LOG.info(hash_indices(indices))
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
# yield a no-op for cases where we don't have any data left to pack
for i in range(0, len_remaining):
yield self.collate_fn(
[
{
"input_ids": [0],
"labels": [-100],
"attention_mask": [True],
"position_ids": [0],
}
]
)
def _len_est(self):
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% (and one extra batch) to allow for variance in packing between random sampler epochs
return (
math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.seq_max_length
// self.batch_size
)
- 1
)
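To make the estimate above concrete, a worked example with invented numbers (10M tokens on this device, efficiency estimate 0.9, sequence length 2048, batch size 4):
import math
math.floor(0.99 * 10_000_000 / 0.9 // 2048 // 4) - 1  # 1341 packed batches expected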
def __len__(self):
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
# the same share of total tokens
# if not self.eff_total_used:
# batches, _ = self.generate_batches(set_stats=True)
# LOG.info(
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
# f"actual packing efficiency: {self.efficiency()}"
# )
return max(1, self._len_est())
def len_w_stats(self):
if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}"
)
return max(1, self._len_est())
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
Some files were not shown because too many files have changed in this diff Show More