Compare commits
62 Commits
flash-attn
...
scatter_mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10328b3429 | ||
|
|
5bfc470d57 | ||
|
|
04168801c9 | ||
|
|
d43a79b7bf | ||
|
|
884d81331e | ||
|
|
2ea75b4160 | ||
|
|
035e680631 | ||
|
|
26fc10df01 | ||
|
|
1bc008e901 | ||
|
|
3f7ed6a784 | ||
|
|
feea977923 | ||
|
|
8df7b888ff | ||
|
|
6366b0c212 | ||
|
|
05bcc9ea56 | ||
|
|
3bd8203c35 | ||
|
|
8b12468230 | ||
|
|
0976781e15 | ||
|
|
8a82d2e0a4 | ||
|
|
4326520829 | ||
|
|
b7d8a7dc4d | ||
|
|
b0ee9ec734 | ||
|
|
0bc114d2e1 | ||
|
|
7659c001aa | ||
|
|
3fd8093717 | ||
|
|
9b6ee83a73 | ||
|
|
638c2dafb5 | ||
|
|
58b0d4b0d8 | ||
|
|
ed70a08348 | ||
|
|
0cfdb2c90c | ||
|
|
37657473c8 | ||
|
|
e0f1895408 | ||
|
|
8984bf1722 | ||
|
|
2598c9f045 | ||
|
|
decb66e170 | ||
|
|
4d09b42ee3 | ||
|
|
b5b44925ec | ||
|
|
170d4d7092 | ||
|
|
00018629e7 | ||
|
|
6b3b271925 | ||
|
|
3a5a2d2f34 | ||
|
|
6d4bbb877f | ||
|
|
0f985e12fe | ||
|
|
c1a7b3dd69 | ||
|
|
2b9687f341 | ||
|
|
2c9c88b32a | ||
|
|
5265cd6b2c | ||
|
|
5be8b555a0 | ||
|
|
0f6af36d50 | ||
|
|
3f69571943 | ||
|
|
1e3d5305d3 | ||
|
|
16482796b0 | ||
|
|
f30d062b48 | ||
|
|
269c5436ea | ||
|
|
e7eed203d8 | ||
|
|
cf002312e0 | ||
|
|
7de912e097 | ||
|
|
d75653407c | ||
|
|
c6b01e0f4a | ||
|
|
cc3cebfa70 | ||
|
|
5894f0e57e | ||
|
|
5cf226e177 | ||
|
|
2ed52bd568 |
1
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
1
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
@@ -59,6 +59,7 @@ body:
|
||||
label: Config yaml
|
||||
description: |
|
||||
Please attach the config yaml!
|
||||
render: yaml
|
||||
|
||||
- type: textarea
|
||||
id: possible-solution
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -17,6 +17,6 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
||||
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -18,6 +18,7 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||
is_latest: true
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
@@ -54,6 +55,7 @@ jobs:
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
|
||||
54
.github/workflows/tests.yml
vendored
54
.github/workflows/tests.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.9", "3.10", "3.11"]
|
||||
python_version: ["3.10", "3.11"]
|
||||
timeout-minutes: 10
|
||||
|
||||
steps:
|
||||
@@ -58,8 +58,8 @@ jobs:
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
timeout-minutes: 30
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 60
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
@@ -70,43 +70,31 @@ jobs:
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.2
|
||||
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
|
||||
num_gpus: 1
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.2
|
||||
num_gpus: 1
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
images: winglian/axolotl-tests
|
||||
- name: Build Docker image
|
||||
python-version: "3.10"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
# Set up build arguments
|
||||
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
|
||||
CUDA="${{ matrix.cuda }}"
|
||||
PYTORCH_VERSION="${{ matrix.pytorch }}"
|
||||
# Build the Docker image
|
||||
docker build . \
|
||||
--file ./docker/Dockerfile-tests \
|
||||
--build-arg BASE_TAG=$BASE_TAG \
|
||||
--build-arg CUDA=$CUDA \
|
||||
--build-arg GITHUB_REF=$GITHUB_REF \
|
||||
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
|
||||
--tag ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} \
|
||||
--no-cache
|
||||
- name: Unit Tests w docker image
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
- name: GPU Unit Tests w docker image
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
|
||||
- name: GPU Unit Tests monkeypatched w docker image
|
||||
run: |
|
||||
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
|
||||
- name: Prune image from docker
|
||||
if: github.ref != 'refs/heads/main'
|
||||
run: |
|
||||
docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
modal run cicd.tests
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -167,3 +167,8 @@ cython_debug/
|
||||
# WandB
|
||||
# wandb creates a folder to store logs for training runs
|
||||
wandb
|
||||
|
||||
# Runs
|
||||
lora-out/*
|
||||
qlora-out/*
|
||||
mlruns/*
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[mypy]
|
||||
|
||||
plugins = pydantic.mypy
|
||||
exclude = venv
|
||||
|
||||
[mypy-alpaca_lora_4bit.*]
|
||||
|
||||
@@ -31,6 +31,7 @@ repos:
|
||||
additional_dependencies:
|
||||
[
|
||||
'types-PyYAML',
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.7.5
|
||||
|
||||
194
README.md
194
README.md
@@ -22,10 +22,10 @@ Features:
|
||||
- [Introduction](#axolotl)
|
||||
- [Supported Features](#axolotl-supports)
|
||||
- [Quickstart](#quickstart-)
|
||||
- [Installation](#installation)
|
||||
- [Environment](#environment)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
|
||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
|
||||
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
@@ -87,15 +87,17 @@ Features:
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
|
||||
|
||||
✅: supported
|
||||
❌: not supported
|
||||
❓: untested
|
||||
|
||||
## Quickstart ⚡
|
||||
|
||||
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
||||
|
||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||
|
||||
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
|
||||
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||
|
||||
### For developers
|
||||
```bash
|
||||
@@ -103,9 +105,18 @@ git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
```
|
||||
|
||||
General case:
|
||||
```
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
|
||||
```
|
||||
pip3 install -e '.'
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
# preprocess datasets - optional but recommended
|
||||
@@ -127,13 +138,14 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
## Installation
|
||||
## Advanced Setup
|
||||
|
||||
### Environment
|
||||
|
||||
#### Docker
|
||||
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
|
||||
```
|
||||
|
||||
Or run on the current files for development:
|
||||
@@ -152,7 +164,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAcc
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
|
||||
```
|
||||
|
||||
It additionally:
|
||||
@@ -167,7 +179,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
</details>
|
||||
|
||||
#### Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
1. Install python >=**3.10**
|
||||
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
@@ -187,6 +199,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||
|
||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
||||
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### Bare Metal Cloud GPU
|
||||
@@ -200,11 +213,11 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
|
||||
1. Install python
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install -y python3.9
|
||||
sudo apt install -y python3.10
|
||||
|
||||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
|
||||
sudo update-alternatives --config python # pick 3.9 if given option
|
||||
python -V # should be 3.9
|
||||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
|
||||
sudo update-alternatives --config python # pick 3.10 if given option
|
||||
python -V # should be 3.10
|
||||
|
||||
```
|
||||
|
||||
@@ -242,15 +255,18 @@ Please use WSL or Docker!
|
||||
|
||||
#### Launching on public clouds via SkyPilot
|
||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||
|
||||
```bash
|
||||
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
|
||||
sky check
|
||||
```
|
||||
|
||||
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
|
||||
```
|
||||
git clone https://github.com/skypilot-org/skypilot.git
|
||||
cd skypilot/llm/axolotl
|
||||
```
|
||||
|
||||
Use one command to launch:
|
||||
```bash
|
||||
# On-demand
|
||||
@@ -260,32 +276,33 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||
```
|
||||
|
||||
|
||||
### Dataset
|
||||
|
||||
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
||||
Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
- `alpaca`: instruction; input(optional)
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
|
||||
```yml
|
||||
datasets:
|
||||
- path: <your-path>
|
||||
type: sharegpt
|
||||
conversation: llama-2
|
||||
```
|
||||
#### Pretraining
|
||||
|
||||
- `completion`: raw corpus
|
||||
```json
|
||||
{"text": "..."}
|
||||
```
|
||||
|
||||
Note: Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
|
||||
|
||||
```yaml
|
||||
pretraining_dataset: # hf path only
|
||||
```
|
||||
|
||||
#### Supervised finetuning
|
||||
|
||||
##### Instruction
|
||||
|
||||
- `alpaca`: instruction; input(optional)
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
|
||||
<details>
|
||||
|
||||
<summary>See other formats</summary>
|
||||
@@ -362,14 +379,37 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
|
||||
```
|
||||
- `pygmalion`: pygmalion
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `metharme`: instruction, adds additional eos tokens
|
||||
```json
|
||||
{"prompt": "...", "generation": "..."}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
##### Template-Free
|
||||
|
||||
- `input_output`: template-free prompt construction
|
||||
```json
|
||||
{"segments": [{"label": true|false, "text": "..."}]}
|
||||
```
|
||||
|
||||
This is a special format that allows you to construct prompts without using templates. This is for advanced users who want more freedom with prompt construction. See [these docs](docs/input_output.md) for more details.
|
||||
|
||||
##### Conversation
|
||||
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
<details>
|
||||
|
||||
<summary>See other formats</summary>
|
||||
|
||||
- `pygmalion`: pygmalion
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
@@ -385,6 +425,8 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
</details>
|
||||
|
||||
Note: `type: sharegpt` opens a special config `conversation:` that enables conversions to many Conversation types. See dataset section under [all yaml options](#all-yaml-options).
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
For a dataset that is preprocessed for instruction purposes:
|
||||
@@ -406,12 +448,16 @@ datasets:
|
||||
format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
```
|
||||
See full config options under [all yaml options](#all-yaml-options).
|
||||
|
||||
#### How to use your custom pretokenized dataset
|
||||
|
||||
- Do not pass a `type:`
|
||||
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
||||
|
||||
```yaml
|
||||
- path: ...
|
||||
```
|
||||
|
||||
### Config
|
||||
|
||||
@@ -425,22 +471,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
|
||||
- dataset
|
||||
```yaml
|
||||
sequence_len: 2048 # max token length for prompt
|
||||
|
||||
# huggingface repo
|
||||
datasets:
|
||||
# huggingface repo
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca # format from earlier
|
||||
type: alpaca
|
||||
|
||||
# huggingface repo with specific configuration/subset
|
||||
datasets:
|
||||
# huggingface repo with specific configuration/subset
|
||||
- path: EleutherAI/pile
|
||||
name: enron_emails
|
||||
type: completion # format from earlier
|
||||
field: text # Optional[str] default: text, field to use for completion data
|
||||
|
||||
# huggingface repo with multiple named configurations/subsets
|
||||
datasets:
|
||||
# huggingface repo with multiple named configurations/subsets
|
||||
- path: bigcode/commitpackft
|
||||
name:
|
||||
- ruby
|
||||
@@ -448,34 +490,29 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
datasets:
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
conversation: chatml # default: vicuna_v1.1
|
||||
|
||||
# local
|
||||
datasets:
|
||||
# local
|
||||
- path: data.jsonl # or json
|
||||
ds_type: json # see other options below
|
||||
type: alpaca
|
||||
|
||||
# dataset with splits, but no train split
|
||||
dataset:
|
||||
# dataset with splits, but no train split
|
||||
- path: knowrohit07/know_sql
|
||||
type: context_qa.load_v2
|
||||
train_on_split: validation
|
||||
|
||||
# loading from s3 or gcs
|
||||
# s3 creds will be loaded from the system default and gcs only supports public access
|
||||
dataset:
|
||||
# loading from s3 or gcs
|
||||
# s3 creds will be loaded from the system default and gcs only supports public access
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
|
||||
# Loading Data From a Public URL
|
||||
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
|
||||
dataset:
|
||||
# Loading Data From a Public URL
|
||||
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
|
||||
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
|
||||
ds_type: json # this is the default, see other options below.
|
||||
```
|
||||
@@ -484,9 +521,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
```yaml
|
||||
load_in_4bit: true
|
||||
load_in_8bit: true
|
||||
|
||||
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
|
||||
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
|
||||
tf32: true # require >=ampere
|
||||
|
||||
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
|
||||
float16: true # use instead of fp16 when you don't want AMP
|
||||
```
|
||||
@@ -494,7 +533,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
|
||||
- lora
|
||||
```yaml
|
||||
adapter: lora # qlora or leave blank for full finetune
|
||||
adapter: lora # 'qlora' or leave blank for full finetune
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
@@ -503,9 +542,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- v_proj
|
||||
```
|
||||
|
||||
<details>
|
||||
<details id="all-yaml-options">
|
||||
|
||||
<summary>All yaml options (click me)</summary>
|
||||
<summary>All yaml options (click to expand)</summary>
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
@@ -517,8 +556,8 @@ base_model_ignore_patterns:
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
revision_of_model:
|
||||
# Optional tokenizer configuration path in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
@@ -535,15 +574,16 @@ tokenizer_legacy:
|
||||
# This is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
# (Internal use only)
|
||||
# Used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
is_qwen_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
is_qwen_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
model_config:
|
||||
overrides_of_model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
@@ -560,8 +600,6 @@ bnb_config_kwargs:
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
gptq_groupsize: 128 # group size
|
||||
gptq_model_v1: false # v1 or v2
|
||||
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
@@ -635,7 +673,7 @@ test_datasets:
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: dpo, ipo, kto_pair
|
||||
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
||||
rl:
|
||||
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
@@ -655,7 +693,7 @@ dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# Only needed if cached dataset is taking too much storage
|
||||
dataset_keep_in_memory:
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # repo path to push finetuned model
|
||||
hub_model_id: # private repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
@@ -751,6 +789,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
||||
# mlflow configuration if you're using it
|
||||
mlflow_tracking_uri: # URI to mlflow
|
||||
mlflow_experiment_name: # Your experiment name
|
||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
@@ -819,10 +858,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
|
||||
# For log_sweep optim
|
||||
log_sweep_min_lr:
|
||||
log_sweep_max_lr:
|
||||
|
||||
# Specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
@@ -1045,6 +1080,10 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
##### FSDP + QLoRA
|
||||
|
||||
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||
@@ -1106,7 +1145,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
|
||||
|
||||
### Merge LORA to base
|
||||
|
||||
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
|
||||
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
|
||||
|
||||
```bash
|
||||
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
|
||||
@@ -1167,7 +1206,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
|
||||
|
||||
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
|
||||
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
|
||||
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
|
||||
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
|
||||
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
|
||||
|
||||
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
|
||||
@@ -1214,11 +1253,20 @@ PRs are **greatly welcome**!
|
||||
|
||||
Please run below to setup env
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
pre-commit install
|
||||
|
||||
# test
|
||||
pytest tests/
|
||||
|
||||
# optional: run against all files
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||
@@ -1255,4 +1303,6 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
|
||||
|
||||
#### 🥉 Bronze Sponsors - $500/mo
|
||||
|
||||
- [JarvisLabs.ai](https://jarvislabs.ai)
|
||||
|
||||
---
|
||||
|
||||
39
cicd/Dockerfile.jinja
Normal file
39
cicd/Dockerfile.jinja
Normal file
@@ -0,0 +1,39 @@
|
||||
FROM winglian/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
ENV BNB_CUDA_VERSION="{{ CUDA }}"
|
||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install pytest
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
5
cicd/cicd.sh
Executable file
5
cicd/cicd.sh
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest /workspace/axolotl/tests/e2e/patched/
|
||||
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
|
||||
75
cicd/tests.py
Normal file
75
cicd/tests.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
modal application to run axolotl gpu tests in Modal
|
||||
"""
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import Image, Stub
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
|
||||
"CUDA": os.environ.get("CUDA", "118"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = (
|
||||
Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
)
|
||||
.env(df_args)
|
||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
)
|
||||
|
||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
||||
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@stub.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=45 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
)
|
||||
def cicd_pytest():
|
||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@stub.local_entrypoint()
|
||||
def main():
|
||||
cicd_pytest.remote()
|
||||
@@ -16,6 +16,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -3,9 +3,10 @@ FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
|
||||
@@ -7,8 +7,8 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@ FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG GITHUB_REF="main"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
@@ -24,9 +25,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
|
||||
@@ -74,7 +74,6 @@ pip3 install -e '.[flash-attn,deepspeed]'
|
||||
|
||||
If you developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh). You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host).
|
||||
|
||||
```bash
|
||||
|
||||
### Configuration
|
||||
|
||||
|
||||
37
docs/fsdp_qlora.md
Normal file
37
docs/fsdp_qlora.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# FDSP + QLoRA
|
||||
|
||||
## Background
|
||||
|
||||
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
|
||||
|
||||
Below, we describe how to use this feature in Axolotl.
|
||||
|
||||
## Usage
|
||||
|
||||
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
|
||||
> ![Tip]
|
||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||
|
||||
1. Set `adapter: qlora` in your axolotl config file.
|
||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Example Config
|
||||
|
||||
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
|
||||
|
||||
## References
|
||||
|
||||
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
|
||||
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
|
||||
- Related HuggingFace PRs Enabling FDSP + QLoRA:
|
||||
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
|
||||
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
|
||||
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
|
||||
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
|
||||
|
||||
|
||||
|
||||
|
||||
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.
|
||||
260
docs/input_output.md
Normal file
260
docs/input_output.md
Normal file
@@ -0,0 +1,260 @@
|
||||
# Template-free prompt construction with the `input_output` format
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [Background](#background)
|
||||
- [Masking Inputs](#masking-inputs)
|
||||
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
|
||||
- [The `input_output` format](#the-input_output-format)
|
||||
- [Usage](#usage)
|
||||
- [1. Prepare Data](#1-prepare-data)
|
||||
- [2. Use `type: input_output`](#2-use-type-input_output)
|
||||
- [3. Check the prompts](#3-check-the-prompts)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
<a id="markdown-background" name="background"></a>
|
||||
|
||||
## Background
|
||||
|
||||
<a id="markdown-masking-inputs" name="masking-inputs"></a>
|
||||
|
||||
### Masking Inputs
|
||||
|
||||
One of the most popular features of
|
||||
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
|
||||
setting the following configuration value:
|
||||
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false
|
||||
```
|
||||
|
||||
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
|
||||
such as `alpaca` or `chatml`, axolotl knows what is an input
|
||||
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
||||
labels so that your model can focus on predicting the outputs only.
|
||||
|
||||
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
|
||||
|
||||
### You may not want prompt templates
|
||||
|
||||
However, there are many situations where you don't want to use one of
|
||||
these formats or templates (I usually don't!). This is because they can:
|
||||
|
||||
- Add unnecessary boilerplate to your prompts.
|
||||
- Create artifacts like special delimiters `<|im_start|>` that can
|
||||
quickly become footguns if you don't include them correctly at
|
||||
inference time.
|
||||
- Enforce a *chat* interface when you do not want one. Sometimes you
|
||||
just want to fine-tune a model to a very specific task and do NOT
|
||||
want multi-turn conversations, roles, etc.
|
||||
- Limit you to only certain roles that the template allows.
|
||||
|
||||
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
|
||||
|
||||
### The `input_output` format
|
||||
|
||||
You can construct your prompts without a template by using the
|
||||
`input_output` format, by setting `type: input_output` in your
|
||||
configuration file like this:
|
||||
|
||||
**config.yml**
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false # Mask segments of your data
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output # use template free prompt construction
|
||||
```
|
||||
|
||||
Unlike `type: completion`, which is also template-free,
|
||||
`type: input_output` allows you to mask segments of your text. More
|
||||
details on how this works are described below.
|
||||
|
||||
<a id="markdown-usage" name="usage"></a>
|
||||
|
||||
## Usage
|
||||
|
||||
This is how you can use the `input_output` format:
|
||||
|
||||
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
|
||||
|
||||
### 1. Prepare Data
|
||||
|
||||
To use the `input_output` format, collect your data in the following
|
||||
format into a jsonl file (below is the first row from the file
|
||||
`output`.jsonl` pretty printed):
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
|
||||
{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Set `label:false` when you want to mask a segment of text so that the
|
||||
model isn't trained on it. Some things to keep in mind:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
|
||||
concatenates all the segments as-is.** The tokenizer doesn't add
|
||||
anything additional. Notice how I added spaces, newlines, `<s>`
|
||||
(BOS), and `</s>` (EOS) myself.
|
||||
> 2. Make sure you check the materialized output to validate that the
|
||||
prompt is getting assembled how you like.
|
||||
|
||||
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
|
||||
|
||||
### 2. Use `type: input_output`
|
||||
|
||||
Let's materialize data with our `output.jsonl` file by setting
|
||||
`type: input_output` in our axolotl config:
|
||||
|
||||
```yaml
|
||||
# training_config.yaml
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
data_seed: 49
|
||||
seed: 49
|
||||
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output
|
||||
val_set_size: 0.1
|
||||
|
||||
sequence_len: 896
|
||||
sample_packing: false
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 3
|
||||
eval_batch_size: 2
|
||||
num_epochs: 1
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
```
|
||||
|
||||
You can use the following command to materialize your data. The
|
||||
`--debug` flag will print the tokens, along with the labels so you can
|
||||
verify that the correct items are being ignored:
|
||||
|
||||
```bash
|
||||
$ python -m axolotl.cli.preprocess training_config.yaml --debug
|
||||
|
||||
...
|
||||
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
|
||||
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
|
||||
|
||||
```
|
||||
|
||||
The format is `decoded_token`(`label`, `token_id`), for example,
|
||||
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
|
||||
token_id is `1`. When the label is `-100` then that token is ignored for
|
||||
training.
|
||||
|
||||
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
|
||||
|
||||
### 3. Check the prompts
|
||||
|
||||
Here is another way to check the materialized output:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from datasets import load_from_disk
|
||||
import yaml
|
||||
|
||||
directory = !ls last_run_prepared/
|
||||
with open('training_config.yaml', 'r') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
model_id = cfg['base_model']
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
|
||||
```
|
||||
|
||||
```python
|
||||
>>> row = ds[0]
|
||||
>>> print(tok.decode(row['input_ids']))
|
||||
<s> Hello
|
||||
hi there!. goodbye farewell</s>
|
||||
```
|
||||
|
||||
We can check that the right tokens are ingored by comparing the labels
|
||||
to each token:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
|
||||
zip(row['input_ids'], row['labels'])])
|
||||
```
|
||||
|
||||
| token | label | id |
|
||||
|-------|-------|-------|
|
||||
| 0 | \<s\> | 1 |
|
||||
| 1 | Hello | 22557 |
|
||||
| 2 | \\n | 13 |
|
||||
| 3 | hi | 12014 |
|
||||
| 4 | there | 736 |
|
||||
| 5 | ! | 28808 |
|
||||
| 6 | . | 28723 |
|
||||
| 7 | | 28705 |
|
||||
| 8 | good | -100 |
|
||||
| 9 | bye | -100 |
|
||||
| 10 | | -100 |
|
||||
| 11 | fare | 19111 |
|
||||
| 12 | well | 5458 |
|
||||
| 13 | \</s\>| 2 |
|
||||
|
||||
|
||||
|
||||
If we look at the input data, the above table seems correct! (The jsonl
|
||||
version is repeated below for reference):
|
||||
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
|
||||
{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
18
docs/mac.md
Normal file
18
docs/mac.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Mac M series support
|
||||
|
||||
Currently Axolotl on Mac is partially usable, many of the dependencies of Axolotl including Pytorch do not support MPS or have incomplete support.
|
||||
|
||||
Current support:
|
||||
- [x] Support for all models
|
||||
- [x] Full training of models
|
||||
- [x] LoRA training
|
||||
- [x] Sample packing
|
||||
- [ ] FP16 and BF16 (awaiting AMP support for MPS in Pytorch)
|
||||
- [ ] Tri-dao's flash-attn (until it is supported use spd_attention as an alternative)
|
||||
- [ ] xformers
|
||||
- [ ] bitsandbytes (meaning no 4/8 bits loading and bnb optimizers)
|
||||
- [ ] qlora
|
||||
- [ ] DeepSpeed
|
||||
|
||||
Untested:
|
||||
- FSDP
|
||||
@@ -22,7 +22,7 @@ lora_target_linear: true
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
70
examples/llama-2/qlora-fsdp.yml
Normal file
70
examples/llama-2/qlora-fsdp.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: yahma/alpaca-cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.00001
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
special_tokens:
|
||||
79
examples/mistral/lora-mps.yml
Normal file
79
examples/mistral/lora-mps.yml
Normal file
@@ -0,0 +1,79 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./lora-out
|
||||
eval_sample_packing: false
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: false
|
||||
sdp_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
74
examples/mistral/mixtral-qlora-fsdp.yml
Normal file
74
examples/mistral/mixtral-qlora-fsdp.yml
Normal file
@@ -0,0 +1,74 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
output_dir: ./qlora-out
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||
special_tokens:
|
||||
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
|
||||
|
||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||
unfrozen_parameters:
|
||||
# - lm_head.*
|
||||
# - model.embed_tokens.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||
# - ^lm_head.weight$
|
||||
# - ^model.embed_tokens.weight$[:32000]
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.gate
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.experts
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.gate
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.experts
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
75
examples/mistral/mixtral_fused.py
Normal file
75
examples/mistral/mixtral_fused.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import gc
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||
from transformers import AutoTokenizer, TextStreamer
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
|
||||
|
||||
def compute_memory_used_pct(device):
|
||||
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
|
||||
memory_pct = (
|
||||
memory_used
|
||||
/ (torch.cuda.get_device_properties(device).total_memory / (1024**3))
|
||||
* 100
|
||||
)
|
||||
return memory_pct
|
||||
|
||||
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
|
||||
# Load model
|
||||
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
|
||||
model = MixtralForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
config=config,
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
|
||||
|
||||
for device_index in range(torch.cuda.device_count()):
|
||||
device_memory_pct = compute_memory_used_pct(device_index)
|
||||
print(device_index, device_memory_pct)
|
||||
|
||||
with tqdm(modules.items(), desc="scatter moe") as pbar:
|
||||
for i, (name, module) in enumerate(pbar):
|
||||
smoe = SparseMoeBlock(
|
||||
experts=module.experts,
|
||||
gate=module.gate,
|
||||
hidden_dim=module.hidden_dim,
|
||||
ffn_dim=module.ffn_dim,
|
||||
num_experts=module.num_experts,
|
||||
top_k=module.top_k,
|
||||
)
|
||||
old_module = model.model.layers[i].block_sparse_moe
|
||||
setattr(model.model.layers[i], "block_sparse_moe", smoe)
|
||||
del old_module
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for device_index in range(torch.cuda.device_count()):
|
||||
device_memory_pct = compute_memory_used_pct(device_index)
|
||||
print(device_index, device_memory_pct)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
# Convert prompt to tokens
|
||||
prompt_template = "[INST] {prompt} [/INST]"
|
||||
|
||||
prompt = "You're standing on the surface of the Earth. "\
|
||||
"You walk one mile south, one mile west and one mile north. "\
|
||||
"You end up exactly where you started. Where are you?"
|
||||
|
||||
tokens = tokenizer(
|
||||
prompt_template.format(prompt=prompt),
|
||||
return_tensors='pt'
|
||||
).input_ids.cuda()
|
||||
|
||||
# Generate output
|
||||
generation_output = model.generate(
|
||||
tokens,
|
||||
streamer=streamer,
|
||||
max_new_tokens=512
|
||||
)
|
||||
69
examples/stablelm-2/1.6b/fft.yml
Normal file
69
examples/stablelm-2/1.6b/fft.yml
Normal file
@@ -0,0 +1,69 @@
|
||||
base_model: stabilityai/stablelm-2-1_6b
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
66
examples/stablelm-2/1.6b/lora.yml
Normal file
66
examples/stablelm-2/1.6b/lora.yml
Normal file
@@ -0,0 +1,66 @@
|
||||
base_model: stabilityai/stablelm-2-1_6b
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
36
examples/stablelm-2/README.md
Normal file
36
examples/stablelm-2/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# StableLM 2
|
||||
|
||||
This repository contains examples for training and processing using StableLM-2. It also includes a section to help you estimate the GPU requirements for your specific use case.
|
||||
|
||||
## Estimating GPU Requirements
|
||||
|
||||
| type | deepspeed | batch size | context length | vRAM GPU (GBs) |
|
||||
|---------------|-----------|------------|----------------|----------------|
|
||||
| full finetune | N/A | 1 | 4096 | ~21.5GBs |
|
||||
| full finetune | zero2 | 1 | 4096 | ~20GBs |
|
||||
| lora | N/A | 1 | 4096 | ~16.6GBs |
|
||||
|
||||
The above are estimates and might differ slight depending on the setup for example whether you pack your sequence lengths or not (the above assumes you do to length 4096).
|
||||
|
||||
This blog post from Hamel Husain was a great resource for estimating these numbers: https://hamel.dev/notes/llm/03_estimating_vram.html
|
||||
|
||||
## Training
|
||||
We have example scripts here for both full finetuning and lora using the popular alpaca dataset:
|
||||
|
||||
```shell
|
||||
# preprocess the dataset
|
||||
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/stablelm-2/1.6b/lora.yml
|
||||
```
|
||||
|
||||
Single GPU Training:
|
||||
```shell
|
||||
python -m axolotl.cli.train examples/stablelm-2/fft.yml --deepspeed deepspeed_configs/zero2.json
|
||||
# OR
|
||||
python -m axolotl.cli.train examples/stablelm-2/1.6b/lora.yml
|
||||
```
|
||||
|
||||
Multinode GPU Training with `accelerate`:
|
||||
```shell
|
||||
# make sure you've configured accelerate properly
|
||||
accelerate launch -m axolotl.cli.train examples/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json
|
||||
```
|
||||
69
examples/starcoder2/qlora.yml
Normal file
69
examples/starcoder2/qlora.yml
Normal file
@@ -0,0 +1,69 @@
|
||||
base_model: bigcode/starcoder2-3b
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.2
|
||||
output_dir: ./qlora
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
eval_steps:
|
||||
eval_table_size:
|
||||
saves_per_epoch: 4
|
||||
save_steps:
|
||||
save_total_limit: 2
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay:
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -15,6 +15,7 @@ output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
|
||||
peft==0.9.0
|
||||
transformers==4.38.2
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
bitsandbytes>=0.43.0
|
||||
accelerate==0.26.1
|
||||
deepspeed>=0.13.1
|
||||
deepspeed==0.13.1
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
flash-attn==2.5.5
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
@@ -21,14 +22,13 @@ hf_transfer
|
||||
colorama
|
||||
numba
|
||||
numpy>=1.24.4
|
||||
mlflow
|
||||
# qlora things
|
||||
evaluate==0.4.1
|
||||
scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.34
|
||||
fschat==0.2.36
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
@@ -40,3 +40,4 @@ gcsfs
|
||||
# adlfs
|
||||
|
||||
trl>=0.7.9
|
||||
fastcore>=1.5.29
|
||||
|
||||
11
setup.py
11
setup.py
@@ -18,6 +18,7 @@ def parse_requirements():
|
||||
or "flash-attention" in line
|
||||
or "deepspeed" in line
|
||||
or "mamba-ssm" in line
|
||||
or "lion-pytorch" in line
|
||||
)
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
@@ -67,13 +68,13 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.5.0",
|
||||
"flash-attn==2.5.5",
|
||||
],
|
||||
"fused-dense-lib": [
|
||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed>=0.13.1",
|
||||
"deepspeed==0.13.1",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
@@ -82,5 +83,11 @@ setup(
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
],
|
||||
"mlflow": [
|
||||
"mlflow",
|
||||
],
|
||||
"lion-pytorch": [
|
||||
"lion-pytorch==0.1.2",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
@@ -24,6 +23,7 @@ from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
@@ -214,6 +214,8 @@ def do_inference_gradio(
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
import gradio as gr
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
cfg.axolotl_config_path = config
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
@@ -341,7 +342,22 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
cfg.axolotl_config_path = config
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
except: # pylint: disable=bare-except # noqa: E722
|
||||
gpu_version = None
|
||||
|
||||
cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": is_torch_bf16_gpu_available(),
|
||||
"n_gpu": os.environ.get("WORLD_SIZE", 1),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
|
||||
0
src/axolotl/core/policies/__init__.py
Normal file
0
src/axolotl/core/policies/__init__.py
Normal file
55
src/axolotl/core/policies/auto_wrap.py
Normal file
55
src/axolotl/core/policies/auto_wrap.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""module for building the auto wrap policy for FSDP"""
|
||||
import functools
|
||||
|
||||
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
_or_policy,
|
||||
lambda_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
|
||||
|
||||
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
|
||||
"llama",
|
||||
"mistral",
|
||||
"mixtral",
|
||||
]
|
||||
|
||||
|
||||
def get_wrapping_policy_factory(model_type):
|
||||
if model_type == "llama":
|
||||
layer_to_wrap = LlamaDecoderLayer
|
||||
elif model_type == "mistral":
|
||||
layer_to_wrap = MistralDecoderLayer
|
||||
elif model_type == "mixtral":
|
||||
layer_to_wrap = MixtralDecoderLayer
|
||||
|
||||
def get_wrapping_policy():
|
||||
"""This checks for lora layers (has weight and requires_grad)"""
|
||||
|
||||
def lambda_policy_fn(module):
|
||||
return (
|
||||
len(list(module.named_children())) == 0
|
||||
and getattr(module, "weight", None) is not None
|
||||
and module.weight.requires_grad
|
||||
)
|
||||
|
||||
lambda_policy = functools.partial(
|
||||
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
|
||||
)
|
||||
transformer_layer_name = layer_to_wrap
|
||||
transformer_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls=(
|
||||
PrefixEncoder,
|
||||
PromptEncoder,
|
||||
PromptEmbedding,
|
||||
transformer_layer_name,
|
||||
),
|
||||
)
|
||||
policies = [lambda_policy, transformer_wrap_policy]
|
||||
return functools.partial(_or_policy, policies=policies)
|
||||
|
||||
return get_wrapping_policy
|
||||
@@ -5,8 +5,10 @@ Builder for the training args and trainer
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
import importlib.util
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
@@ -16,7 +18,10 @@ from typing import List, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate import FullyShardedDataParallelPlugin
|
||||
from accelerate.utils import str_to_bool
|
||||
from datasets import Dataset
|
||||
from torch.distributed.fsdp import MixedPrecision
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
@@ -26,15 +31,17 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
|
||||
from axolotl.loraplus import create_loraplus_optimizer
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
@@ -54,6 +61,9 @@ from axolotl.utils.schedulers import (
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
except ImportError:
|
||||
@@ -62,6 +72,10 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
@@ -175,6 +189,17 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -199,6 +224,33 @@ class AxolotlTrainer(Trainer):
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
loraplus_lr_ratio,
|
||||
loraplus_lr_embedding,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
@@ -425,6 +477,56 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
|
||||
if self.args.qlora is False:
|
||||
return res
|
||||
|
||||
# the rest of this method override is specific to fsdp + qlora (for now)
|
||||
sync_module_states = (
|
||||
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
|
||||
)
|
||||
|
||||
mp_policy = None
|
||||
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
|
||||
if amp == "fp16":
|
||||
mp_policy = MixedPrecision(
|
||||
param_dtype=torch.float32,
|
||||
reduce_dtype=torch.float32,
|
||||
buffer_dtype=torch.float32,
|
||||
)
|
||||
elif amp == "bf16":
|
||||
mp_policy = MixedPrecision(
|
||||
param_dtype=torch.float32,
|
||||
reduce_dtype=torch.float32,
|
||||
buffer_dtype=torch.float32,
|
||||
)
|
||||
|
||||
# If somehow we figure out how we want to parameterize we want to autocast buffers...
|
||||
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
|
||||
# load_param_skip_names = ['inv_freq']
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=wrapping_policy(),
|
||||
cpu_offload=False,
|
||||
use_orig_params=False,
|
||||
limit_all_gathers=True,
|
||||
param_init_fn=lambda module: module.to_empty(
|
||||
device=torch.device("cuda"), recurse=False
|
||||
)
|
||||
if (rank != 0 and sync_module_states)
|
||||
else None,
|
||||
mixed_precision_policy=mp_policy,
|
||||
)
|
||||
self.accelerator.state.fsdp_plugin = fsdp_plugin
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
@@ -648,7 +750,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow:
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
@@ -740,6 +846,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||
|
||||
if self.cfg.adapter == "qlora":
|
||||
training_arguments_kwargs["qlora"] = True
|
||||
|
||||
# deepspeed
|
||||
if self.cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
@@ -907,6 +1016,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
if self.cfg.lr_scheduler
|
||||
@@ -962,18 +1075,42 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"neftune_noise_alpha"
|
||||
] = self.cfg.neftune_noise_alpha
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.optimizer == "lion_pytorch":
|
||||
from lion_pytorch import Lion
|
||||
|
||||
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
|
||||
if "weight_decay" in training_arguments_kwargs:
|
||||
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
|
||||
|
||||
if (
|
||||
"adam_beta1" in training_arguments_kwargs
|
||||
and "adam_beta2" in training_arguments_kwargs
|
||||
):
|
||||
lion_kwargs["betas"] = (
|
||||
training_arguments_kwargs["adam_beta1"],
|
||||
training_arguments_kwargs["adam_beta2"],
|
||||
)
|
||||
|
||||
trainer_kwargs["optimizers"] = (
|
||||
Lion(params=self.model.parameters(), **lion_kwargs),
|
||||
None,
|
||||
)
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
|
||||
@@ -30,6 +30,7 @@ class ColorfulFormatter(Formatter):
|
||||
|
||||
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
||||
|
||||
133
src/axolotl/loraplus.py
Normal file
133
src/axolotl/loraplus.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Module for LoRA+"""
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2024 nikhil-ghosh-berkeley
|
||||
# https://github.com/nikhil-ghosh-berkeley/loraplus
|
||||
|
||||
import logging
|
||||
from functools import reduce
|
||||
|
||||
from peft.tuners import lora
|
||||
from torch import nn
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
LOG = logging.getLogger("axolotl.loraplus")
|
||||
|
||||
|
||||
def get_module(name, opt_model):
|
||||
"""
|
||||
Retrieve a module from a model using its parameter name.
|
||||
Args:
|
||||
name (str): Full name of the parameter, typically including module path.
|
||||
opt_model (torch.nn.Module): The model from which to retrieve the module.
|
||||
|
||||
Returns:
|
||||
Module corresponding to the given name.
|
||||
"""
|
||||
parent_idx = 2 if "lora" in name else 1
|
||||
module_names = name.split(sep=".")[:-parent_idx]
|
||||
module = reduce(getattr, module_names, opt_model)
|
||||
return module
|
||||
|
||||
|
||||
def create_loraplus_optimizer(
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=None,
|
||||
):
|
||||
"""
|
||||
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
|
||||
|
||||
Args:
|
||||
opt_model (torch.nn.Module): The model for which the optimizer is being created.
|
||||
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
|
||||
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
|
||||
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
|
||||
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
|
||||
|
||||
Returns:
|
||||
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
|
||||
"""
|
||||
|
||||
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
|
||||
|
||||
if loraplus_lr_embedding is None:
|
||||
loraplus_lr_embedding = 1e-6
|
||||
|
||||
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||
param_groups = {
|
||||
"groupA": {},
|
||||
"groupB": {},
|
||||
"groupB_no_decay": {},
|
||||
"embedding": {},
|
||||
}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
module = get_module(name, opt_model)
|
||||
if isinstance(module, lora.Embedding):
|
||||
param_groups["embedding"][name] = param
|
||||
elif "lora_B" in name or param.ndim == 1:
|
||||
if name in decay_parameters:
|
||||
param_groups["groupB"][name] = param
|
||||
else:
|
||||
param_groups["groupB_no_decay"][name] = param
|
||||
else:
|
||||
param_groups["groupA"][name] = param
|
||||
|
||||
assigned_param_groups = ""
|
||||
for group, group_params in param_groups.items():
|
||||
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
|
||||
LOG.info(assigned_param_groups)
|
||||
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
|
||||
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": list(param_groups["groupA"].values()),
|
||||
"weight_decay": weight_decay,
|
||||
"lr": lr,
|
||||
},
|
||||
{
|
||||
"params": list(param_groups["embedding"].values()),
|
||||
"weight_decay": weight_decay,
|
||||
"lr": loraplus_lr_embedding,
|
||||
},
|
||||
{
|
||||
"params": list(param_groups["groupB"].values()),
|
||||
"weight_decay": weight_decay,
|
||||
"lr": lr * loraplus_lr_ratio,
|
||||
},
|
||||
{
|
||||
"params": list(param_groups["groupB_no_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr * loraplus_lr_ratio,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
return optimizer
|
||||
@@ -106,7 +106,7 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
if self.system_message:
|
||||
contains_sys_msg = True
|
||||
if self.messages:
|
||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
|
||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt = self.system_template.format(
|
||||
|
||||
@@ -44,6 +44,18 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def is_xformers_swiglu_available() -> bool:
|
||||
from xformers.ops.common import get_xformers_operator
|
||||
|
||||
try:
|
||||
get_xformers_operator("swiglu_packedw")()
|
||||
return True
|
||||
except RuntimeError as exc:
|
||||
if "No such operator xformers::swiglu_packedw " in str(exc):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def replace_llama_mlp_with_swiglu(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaMLP):
|
||||
|
||||
0
src/axolotl/monkeypatch/moe/__init__.py
Normal file
0
src/axolotl/monkeypatch/moe/__init__.py
Normal file
149
src/axolotl/monkeypatch/moe/linear.py
Normal file
149
src/axolotl/monkeypatch/moe/linear.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Adapted from:
|
||||
https://github.com/shawntan/scattermoe
|
||||
https://arxiv.org/abs/2403.08245
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from axolotl.monkeypatch.moe import ops
|
||||
|
||||
class ParallelLinear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, x, expert_weights, k,
|
||||
sorted_expert_idxs, sorted_scattered_idxs,
|
||||
padded_block_idxs, expert_offsets,
|
||||
gates=None, grouped_in=False, grouped_out=False,
|
||||
):
|
||||
|
||||
output = ops.scatter2scatter(
|
||||
X=x, W=expert_weights,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
padded_block_idxs=padded_block_idxs,
|
||||
k=k, x_grouped=grouped_in, y_grouped=grouped_out
|
||||
)
|
||||
if gates is not None:
|
||||
output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
|
||||
output = torch.bmm(
|
||||
gates[:, None, :],
|
||||
output_expanded
|
||||
).squeeze(1)
|
||||
else:
|
||||
output_expanded = None
|
||||
|
||||
ctx.save_for_backward(
|
||||
x, expert_weights,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
padded_block_idxs, expert_offsets,
|
||||
gates,
|
||||
output_expanded
|
||||
)
|
||||
ctx.grouped_in = grouped_in
|
||||
ctx.grouped_out = grouped_out
|
||||
ctx.k = k
|
||||
return output
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(x, expert_weights,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
padded_block_idxs, expert_offsets,
|
||||
gates, output_expanded) = ctx.saved_tensors
|
||||
k = ctx.k
|
||||
grouped_in = ctx.grouped_in
|
||||
grouped_out = ctx.grouped_out
|
||||
# print("backward")
|
||||
if gates is not None:
|
||||
# calculate gates gradient
|
||||
d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
|
||||
gates_flat = gates.flatten()
|
||||
gate_fan = gates.size(1)
|
||||
# print("expanded and grouping")
|
||||
grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later
|
||||
else:
|
||||
d_gates = None
|
||||
gates_flat = None
|
||||
gate_fan = 1
|
||||
grouped_grad_out = None
|
||||
|
||||
if grouped_out:
|
||||
grouped_grad_out = grad_out
|
||||
else:
|
||||
grouped_grad_out = ops.group(grad_out, sorted_scattered_idxs,
|
||||
fan_out=gate_fan, coeff=gates_flat,
|
||||
out=grouped_grad_out)
|
||||
if grouped_in:
|
||||
grouped_x = x
|
||||
d_expanded_input = None
|
||||
else:
|
||||
grouped_x = ops.group(x, sorted_scattered_idxs, fan_out=k)
|
||||
d_expanded_input = grouped_x
|
||||
d_weights = ops.group_bwd_W(
|
||||
DY=grouped_grad_out, X=grouped_x,
|
||||
expert_offsets=expert_offsets,
|
||||
E=expert_weights.size(0)
|
||||
)
|
||||
d_expanded_input = ops.scatter2scatter(
|
||||
X=grouped_grad_out, x_grouped=True,
|
||||
W=expert_weights.permute(0, 2, 1),
|
||||
padded_block_idxs=padded_block_idxs,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
y_grouped=grouped_in,
|
||||
out=d_expanded_input # Reuse grouped_x buffer
|
||||
)
|
||||
|
||||
if k == 1:
|
||||
d_input = d_expanded_input
|
||||
else:
|
||||
d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
|
||||
# print("backward end.")
|
||||
return (
|
||||
# x, expert_weights, k,
|
||||
d_input, d_weights, None,
|
||||
# sorted_expert_idxs, sorted_scattered_idxs,
|
||||
None, None,
|
||||
# padded_block_idxs, expert_offsets,
|
||||
None, None,
|
||||
# gates
|
||||
d_gates, None, None
|
||||
)
|
||||
|
||||
def parallel_linear(inputs, expert_weights, k,
|
||||
sorted_expert_idxs, sorted_scattered_idxs,
|
||||
padded_block_idxs, expert_offsets,
|
||||
gates=None):
|
||||
results = ParallelLinear.apply(inputs, expert_weights, k,
|
||||
sorted_expert_idxs, sorted_scattered_idxs,
|
||||
padded_block_idxs, expert_offsets, gates)
|
||||
return results
|
||||
|
||||
class ParallelExperts(nn.Module):
|
||||
def __init__(self, num_experts, input_size, output_size, device) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(num_experts, output_size, input_size, device=device)
|
||||
)
|
||||
self.num_experts = num_experts
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
|
||||
def extra_repr(self):
|
||||
return 'num_experts={}, input_size={}, output_size={}'.format(
|
||||
self.num_experts, self.input_size, self.output_size)
|
||||
|
||||
def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
|
||||
padded_block_idxs, expert_offsets,
|
||||
gates=None, grouped_in=False, grouped_out=False):
|
||||
|
||||
results = ParallelLinear.apply(
|
||||
inputs, self.weight.permute(0, 2, 1), k,
|
||||
sorted_expert_idxs, sorted_scattered_idxs,
|
||||
padded_block_idxs, expert_offsets,
|
||||
gates, grouped_in, grouped_out
|
||||
)
|
||||
return results
|
||||
86
src/axolotl/monkeypatch/moe/mlp.py
Normal file
86
src/axolotl/monkeypatch/moe/mlp.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Adapted from:
|
||||
https://github.com/shawntan/scattermoe
|
||||
https://arxiv.org/abs/2403.08245
|
||||
"""
|
||||
|
||||
import gc
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.monkeypatch.moe import ops
|
||||
from axolotl.monkeypatch.moe.linear import ParallelExperts
|
||||
|
||||
|
||||
class FusedExperts(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
experts: nn.ModuleList =None,
|
||||
hidden_dim=128,
|
||||
ffn_dim=512,
|
||||
num_experts=8,
|
||||
top_k=2,
|
||||
activation=nn.SiLU(),
|
||||
):
|
||||
"""
|
||||
This implements fused experts that are compatible with Mixtral.
|
||||
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
|
||||
"""
|
||||
super(FusedExperts, self).__init__()
|
||||
|
||||
device = experts[0].w1.weight.device
|
||||
self.num_experts = num_experts
|
||||
self.hidden_dim = hidden_dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
|
||||
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
|
||||
self.top_k = min(top_k, self.num_experts)
|
||||
self.activation = activation
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(len(experts)):
|
||||
self.experts.weight.data[i].copy_(
|
||||
torch.cat(
|
||||
[experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
|
||||
dim=0
|
||||
)
|
||||
)
|
||||
self.output_experts.weight.data[i].copy_(
|
||||
experts[i].w2.weight.detach()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
||||
):
|
||||
x_shape = x.size()
|
||||
x = x.view(-1, x_shape[-1])
|
||||
with torch.no_grad():
|
||||
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
|
||||
selected_experts
|
||||
)
|
||||
padded_block_idxs, expert_offsets = ops.padded_block_indices(
|
||||
sorted_expert_idxs, self.num_experts
|
||||
)
|
||||
|
||||
h, gates = self.experts(
|
||||
x,
|
||||
self.top_k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
padded_block_idxs,
|
||||
expert_offsets,
|
||||
grouped_out=True,
|
||||
).chunk(2, dim=-1)
|
||||
h = self.activation(gates) * h
|
||||
y = self.output_experts(
|
||||
h,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
padded_block_idxs,
|
||||
expert_offsets,
|
||||
grouped_in=True,
|
||||
gates=routing_weights,
|
||||
)
|
||||
y = y.view(*x_shape[:-1], y.size(-1))
|
||||
return y
|
||||
50
src/axolotl/monkeypatch/moe/moe.py
Normal file
50
src/axolotl/monkeypatch/moe/moe.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||
|
||||
class SparseMoeBlock(nn.Module):
|
||||
def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.gate = gate
|
||||
self.experts = FusedExperts(
|
||||
experts=experts,
|
||||
hidden_dim=hidden_dim,
|
||||
ffn_dim=ffn_dim,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
activation=experts[0].act_fn
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
# get original weights back: reverse the concat + stack in the fused experts
|
||||
w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
|
||||
w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
|
||||
|
||||
# TODO: recreate MoE class with original weights
|
||||
experts = []
|
||||
for i in range(self.num_experts):
|
||||
pass
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
# Fused expert forward
|
||||
final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
353
src/axolotl/monkeypatch/moe/ops.py
Normal file
353
src/axolotl/monkeypatch/moe/ops.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Adapted from:
|
||||
https://github.com/shawntan/scattermoe
|
||||
https://arxiv.org/abs/2403.08245
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.nn import functional as F
|
||||
|
||||
BLOCK_M = 128
|
||||
|
||||
@torch.jit.script
|
||||
def flatten_and_sort(expert_idxs:torch.Tensor):
|
||||
flattened_expert_idxs = expert_idxs.flatten()
|
||||
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
|
||||
return sorted_expert_idxs, sorted_scattered_idxs
|
||||
|
||||
@torch.jit.script
|
||||
def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int=BLOCK_M) :
|
||||
expert_counts = torch.bincount(sorted_experts_idxs, minlength=k)
|
||||
padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1
|
||||
padded_expert_block_end = padded_block_counts.cumsum(-1)
|
||||
expert_boundaries_end = expert_counts.cumsum(-1)
|
||||
expert_boundaries_start = expert_boundaries_end - expert_counts
|
||||
padded_expert_block_start = padded_expert_block_end - padded_block_counts
|
||||
block_idxs = torch.arange(padded_expert_block_end[-1],
|
||||
dtype=sorted_experts_idxs.dtype,
|
||||
device=sorted_experts_idxs.device)
|
||||
block_mask = (
|
||||
(block_idxs[:, None] < padded_expert_block_start) |
|
||||
(block_idxs[:, None] >= padded_expert_block_end)
|
||||
)
|
||||
expanded_block_idxs = (
|
||||
N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) +
|
||||
expert_boundaries_start
|
||||
)
|
||||
expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1)
|
||||
return expanded_block_idxs, expert_boundaries_end
|
||||
|
||||
|
||||
|
||||
def _scatter2scatter_configs():
|
||||
return [
|
||||
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
|
||||
]
|
||||
|
||||
@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
|
||||
@triton.heuristics({
|
||||
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
|
||||
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _scatter2scatter(
|
||||
X_ptr, stride_xm, stride_xk,
|
||||
W_ptr, stride_we, stride_wk, stride_wn,
|
||||
Y_ptr, stride_ym, stride_yn,
|
||||
grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
OUT_M: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
x_grouped: tl.constexpr, y_grouped: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
|
||||
M_block_id = pid // N_BLOCK_COUNT
|
||||
N_block_id = pid % N_BLOCK_COUNT
|
||||
M_range = tl.arange(0, BLOCK_M)
|
||||
block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
|
||||
# M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M)
|
||||
M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
|
||||
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
|
||||
E_idx = tl.min(E_idxs)
|
||||
E_mask = E_idxs == E_idx
|
||||
M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
|
||||
if x_grouped:
|
||||
M_in_idx = M_block
|
||||
else:
|
||||
M_in_idx = M_idx // FAN_OUT
|
||||
|
||||
if y_grouped:
|
||||
M_out_idx = M_block
|
||||
else:
|
||||
M_out_idx = M_idx
|
||||
|
||||
K_block = tl.arange(0, BLOCK_K)
|
||||
|
||||
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_block < N
|
||||
# N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
|
||||
# N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
|
||||
W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
|
||||
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
iters = tl.cdiv(K, BLOCK_K)
|
||||
for K_block_id in range(0, iters):
|
||||
if NO_K_MASK:
|
||||
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
|
||||
if NO_N_MASK:
|
||||
w = tl.load(W_blk_ptrs)
|
||||
else:
|
||||
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
|
||||
else:
|
||||
K_mask = (K_block_id * BLOCK_K + K_block) < K
|
||||
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
|
||||
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
|
||||
X_blk_ptrs += BLOCK_K * stride_xk
|
||||
W_blk_ptrs += BLOCK_K * stride_wk
|
||||
acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)
|
||||
|
||||
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
|
||||
tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
|
||||
|
||||
def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
|
||||
padded_block_idxs, x_grouped=False, y_grouped=False,
|
||||
out=None):
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
# Pre-kernel setup
|
||||
x_dim = X.size(-1)
|
||||
y_dim = W.size(-1)
|
||||
L_scattered = sorted_expert_idxs.size(0)
|
||||
if out is None:
|
||||
O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
|
||||
else:
|
||||
assert out.size(0) == L_scattered and out.size(1) == y_dim
|
||||
O = out
|
||||
|
||||
def grid(META):
|
||||
grid_num = (
|
||||
padded_block_idxs.size(0) *
|
||||
triton.cdiv(META['N'], META['BLOCK_N']),
|
||||
)
|
||||
return grid_num
|
||||
"""
|
||||
print("X", X.size(), X.stride(),
|
||||
"W", W.size(), W.stride(),
|
||||
"O", O.size(), O.stride(),
|
||||
"sorted_idxs", sorted_scattered_idxs.size(),
|
||||
"FAN_OUT", k,
|
||||
"BLOCK_M", BLOCK_M,
|
||||
"grouped", (x_grouped, y_grouped))
|
||||
"""
|
||||
_scatter2scatter[grid](
|
||||
# X_ptr, stride_xm, stride_xk,
|
||||
X, X.stride(0), X.stride(1),
|
||||
# W_ptr, stride_we, stride_wk, stride_wn,
|
||||
W, W.stride(0), W.stride(1), W.stride(2),
|
||||
# Y_ptr, stride_ym, stride_yn,
|
||||
O, O.stride(0), O.stride(1),
|
||||
grouped_idx_ptr=sorted_scattered_idxs,
|
||||
expert_idxs_ptr=sorted_expert_idxs,
|
||||
block_start_idx_ptr=padded_block_idxs,
|
||||
FAN_OUT=k,
|
||||
M=X.size(0),
|
||||
K=X.size(1),
|
||||
N=O.size(1), E=W.size(0),
|
||||
BLOCK_M=BLOCK_M,
|
||||
ACC_TYPE=tl.float32,
|
||||
OUT_M=O.size(0),
|
||||
allow_tf32=True,
|
||||
x_grouped=x_grouped, y_grouped=y_grouped,
|
||||
)
|
||||
return O
|
||||
|
||||
|
||||
def _config_XtY():
|
||||
return [
|
||||
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
|
||||
]
|
||||
|
||||
def group_bwd_W(DY, X, expert_offsets, E):
|
||||
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
|
||||
DW = DWt.permute(0, 2, 1)
|
||||
def grid(META):
|
||||
grid = (
|
||||
E * triton.cdiv(META['K'], META['BLOCK_K']),
|
||||
triton.cdiv(META['N'], META['BLOCK_N']),
|
||||
)
|
||||
return grid
|
||||
_groupXtY[grid](
|
||||
# DY_ptr, stride_dym, stride_dyk,
|
||||
DY, DY.stride(0), DY.stride(1),
|
||||
# X_ptr, stride_xm, stride_xn,
|
||||
X, X.stride(0), X.stride(1),
|
||||
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
|
||||
DW, DW.stride(0), DW.stride(1), DW.stride(2),
|
||||
# expert_offsets_ptr,
|
||||
expert_offsets,
|
||||
# K: tl.constexpr, N: tl.constexpr,
|
||||
M=DY.size(0), N=DY.size(-1), K=X.size(-1),
|
||||
# ACC_TYPE: tl.constexpr,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=True
|
||||
)
|
||||
return DW
|
||||
|
||||
@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
|
||||
@triton.heuristics({
|
||||
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
|
||||
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _groupXtY(
|
||||
DY_ptr, stride_dym, stride_dyk,
|
||||
X_ptr, stride_xm, stride_xn,
|
||||
DW_ptr, stride_dwe, stride_dwk, stride_dwn,
|
||||
expert_offsets_ptr,
|
||||
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
|
||||
):
|
||||
pid0 = tl.program_id(axis=0)
|
||||
pid1 = tl.program_id(axis=1)
|
||||
num0 = tl.num_programs(0)
|
||||
num1 = tl.num_programs(1)
|
||||
pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
|
||||
|
||||
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
|
||||
E_idx = pid0 // K_BLOCK_COUNT
|
||||
K_block_id = pid0 % K_BLOCK_COUNT
|
||||
N_block_id = pid1
|
||||
|
||||
if E_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||
|
||||
if end_idx > start_idx:
|
||||
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
|
||||
|
||||
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
K_mask = K_block < K
|
||||
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
|
||||
|
||||
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_block < N
|
||||
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
M_idxs = M_block
|
||||
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
|
||||
dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
|
||||
|
||||
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
|
||||
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
|
||||
for i in range(0, iters):
|
||||
M_mask = (i * BLOCK_M + M_block) < end_idx
|
||||
if NO_K_MASK:
|
||||
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
|
||||
else:
|
||||
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
|
||||
if NO_N_MASK:
|
||||
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
|
||||
else:
|
||||
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
|
||||
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
|
||||
xt_blk_ptrs += BLOCK_M * stride_xm
|
||||
dy_blk_ptrs += BLOCK_M * stride_dym
|
||||
|
||||
|
||||
DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
|
||||
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
|
||||
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def _config_grouping():
|
||||
return [
|
||||
triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
|
||||
]
|
||||
|
||||
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
|
||||
N = sorted_expert_idxs.size(0)
|
||||
K = A.size(1)
|
||||
assert A.size(0) * fan_out == N
|
||||
if out is not None:
|
||||
Y = out
|
||||
else:
|
||||
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
|
||||
# print("grp init:", Y.size())
|
||||
def grid(META):
|
||||
grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
|
||||
return grid_num
|
||||
_group[grid](
|
||||
# A_ptr, stride_an, stride_ai,
|
||||
A, A.stride(0), A.stride(1), coeff is not None, coeff, fan_out,
|
||||
# Y_ptr, stride_yn, stride_yk,
|
||||
Y, Y.stride(0), Y.stride(1),
|
||||
# grouped_idx_ptr,
|
||||
sorted_expert_idxs,
|
||||
# N: tl.constexpr, K: tl.constexpr,
|
||||
N, K
|
||||
)
|
||||
return Y
|
||||
|
||||
@triton.autotune(configs=_config_grouping(), key=['K'])
|
||||
@triton.heuristics({
|
||||
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
|
||||
})
|
||||
@triton.jit
|
||||
def _group(
|
||||
src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
|
||||
tgt_ptr, stride_tn, stride_ti,
|
||||
grouped_idx_ptr,
|
||||
N: tl.constexpr, K: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
N_block_id = pid
|
||||
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_blk < N
|
||||
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
|
||||
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
|
||||
|
||||
K_blk = tl.arange(0, BLOCK_K)
|
||||
src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
|
||||
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
|
||||
|
||||
if has_coeff:
|
||||
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
|
||||
|
||||
iters = tl.cdiv(K, BLOCK_K)
|
||||
for i in range(0, iters):
|
||||
if NO_K_MASK:
|
||||
block = tl.load(src_blk_ptrs) # , mask=N_mask[:, None])
|
||||
if has_coeff:
|
||||
block *= c
|
||||
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
|
||||
|
||||
else:
|
||||
K_mask = (i * BLOCK_K + K_blk) < K
|
||||
mask = N_mask[:, None] & K_mask[None, :]
|
||||
block = tl.load(src_blk_ptrs, mask=mask)
|
||||
if has_coeff:
|
||||
block *= c
|
||||
tl.store(tgt_blk_ptrs, block, mask=mask)
|
||||
|
||||
src_blk_ptrs += BLOCK_K * stride_sk
|
||||
tgt_blk_ptrs += BLOCK_K * stride_ti
|
||||
66
src/axolotl/monkeypatch/moe/single.py
Normal file
66
src/axolotl/monkeypatch/moe/single.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Adapted from:
|
||||
https://github.com/shawntan/scattermoe
|
||||
https://arxiv.org/abs/2403.08245
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.nn import functional as F
|
||||
|
||||
@triton.jit
|
||||
def _single2scatter(
|
||||
X_ptr, stride_xm, stride_xk,
|
||||
W_ptr, stride_we, stride_wk, stride_wn,
|
||||
Y_ptr, stride_ym, stride_yn,
|
||||
expert_idxs_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
):
|
||||
pid0 = tl.program_id(axis=0)
|
||||
pid1 = tl.program_id(axis=1)
|
||||
|
||||
N_block_id = pid0
|
||||
if FAN_OUT == 1:
|
||||
in_idx = pid1
|
||||
else:
|
||||
in_idx = 0
|
||||
out_idx = pid1
|
||||
|
||||
K_block = tl.arange(0, BLOCK_K)
|
||||
N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
|
||||
E_idx = tl.load(expert_idxs_ptr + pid1)
|
||||
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
|
||||
W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
|
||||
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
|
||||
for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
x = tl.load(X_blk_ptrs)
|
||||
w = tl.load(W_blk_ptrs)
|
||||
acc += tl.sum(x * w, axis=0)[None, :]
|
||||
X_blk_ptrs += BLOCK_K * stride_xk
|
||||
W_blk_ptrs += BLOCK_K * stride_wk
|
||||
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
|
||||
tl.store(Y_blk_ptrs, acc)
|
||||
|
||||
def single2scatter(X, W, expert_idxs):
|
||||
E, xdim, ydim = W.size()
|
||||
k = expert_idxs.size(1)
|
||||
assert X.size(0) == k or X.size(0) == 1
|
||||
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 128
|
||||
grid = ydim // BLOCK_N, k
|
||||
_single2scatter[grid](
|
||||
X, X.stride(0), X.stride(1),
|
||||
W, W.stride(0), W.stride(1), W.stride(2),
|
||||
Y, Y.stride(0), Y.stride(1),
|
||||
expert_idxs,
|
||||
FAN_OUT=Y.size(0) // X.size(0),
|
||||
K=xdim, N=ydim, E=E,
|
||||
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
|
||||
ACC_TYPE=tl.float32
|
||||
)
|
||||
return Y
|
||||
@@ -1,15 +1,26 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
import importlib
|
||||
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"mixtral",
|
||||
"qwen2",
|
||||
"falcon",
|
||||
"phi",
|
||||
"gemma",
|
||||
"gemmoe",
|
||||
"starcoder2",
|
||||
]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
def patch_for_multipack(model_type, model_name=None):
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
@@ -32,3 +43,19 @@ def patch_for_multipack(model_type):
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "starcoder2":
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemmoe":
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_gemmoe to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_gemmoe", ".modeling_gemmoe"
|
||||
)
|
||||
modeling_gemmoe = importlib.import_module(module_name)
|
||||
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
@@ -267,7 +267,7 @@ class ReLoRAScheduler(LRScheduler):
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
|
||||
if step < self.relora_steps:
|
||||
if step < self.relora_steps - self.warmup_steps:
|
||||
scale = 1
|
||||
else:
|
||||
per_relora_progress = step % self.relora_steps
|
||||
|
||||
78
src/axolotl/prompt_strategies/chat_template.py
Normal file
78
src/axolotl/prompt_strategies/chat_template.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
HF Chat Templates prompt strategy
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
|
||||
class ChatTemplatePrompter(Prompter):
|
||||
"""prompter for HF chat templates"""
|
||||
|
||||
def __init__(self, tokenizer, chat_template=None, max_length=2048):
|
||||
self.tokenizer = tokenizer
|
||||
self.chat_template = chat_template
|
||||
self.max_length = max_length
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template=self.chat_template,
|
||||
)
|
||||
|
||||
|
||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
|
||||
input_ids = self.prompter.build_prompt(turns)
|
||||
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(prompt_ids)
|
||||
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
|
||||
else:
|
||||
labels = input_ids
|
||||
|
||||
tokenized_prompt = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(input_ids),
|
||||
}
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {
|
||||
"human": "user",
|
||||
"user": "user",
|
||||
"assistant": "assistant",
|
||||
"gpt": "assistant",
|
||||
}
|
||||
turns = [
|
||||
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
chat_template = (
|
||||
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
||||
)
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
return strategy
|
||||
@@ -8,14 +8,13 @@ import logging
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def load(strategy, cfg):
|
||||
def load(strategy, cfg, **kwargs):
|
||||
try:
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
return func(cfg, **load_kwargs)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"unable to load strategy {strategy}")
|
||||
return None
|
||||
|
||||
@@ -5,6 +5,7 @@ DPO strategies for chatml
|
||||
|
||||
def argilla(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
@@ -23,8 +24,28 @@ def argilla(
|
||||
return transform_fn
|
||||
|
||||
|
||||
def argilla_chat(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for argilla/dpo-mix-7k conversations
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def icr(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
@@ -48,7 +69,7 @@ def icr(
|
||||
return transform_fn
|
||||
|
||||
|
||||
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
For Intel Orca DPO Pairs
|
||||
"""
|
||||
@@ -70,7 +91,9 @@ def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
return transform_fn
|
||||
|
||||
|
||||
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def prompt_pairs(
|
||||
cfg, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
@@ -88,7 +111,7 @@ def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argume
|
||||
return transform_fn
|
||||
|
||||
|
||||
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for ultrafeedback binarized conversations
|
||||
"""
|
||||
|
||||
41
src/axolotl/prompt_strategies/dpo/user_defined.py
Normal file
41
src/axolotl/prompt_strategies/dpo/user_defined.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
User-defined DPO strategies
|
||||
"""
|
||||
|
||||
|
||||
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
||||
ds_cfg = cfg["datasets"][dataset_idx]["type"]
|
||||
if not isinstance(ds_cfg, dict):
|
||||
raise ValueError(
|
||||
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
|
||||
)
|
||||
field_prompt = ds_cfg.get("field_prompt", "prompt")
|
||||
field_system = ds_cfg.get("field_system", "system")
|
||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||
prompt_format = ds_cfg.get("prompt_format")
|
||||
if not prompt_format:
|
||||
prompt_format = "{" + field_prompt + "}"
|
||||
chosen_format = ds_cfg.get("chosen_format")
|
||||
if not chosen_format:
|
||||
chosen_format = "{" + field_chosen + "}"
|
||||
rejected_format = ds_cfg.get("rejected_format")
|
||||
if not rejected_format:
|
||||
rejected_format = "{" + field_rejected + "}"
|
||||
|
||||
def transform_fn(sample):
|
||||
if (
|
||||
"{" + field_system + "}" in prompt_format
|
||||
and field_system in sample
|
||||
and sample[field_system]
|
||||
):
|
||||
sample["prompt"] = prompt_format.format(
|
||||
system=sample[field_system], prompt=sample[field_prompt]
|
||||
)
|
||||
else:
|
||||
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
||||
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
||||
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -3,7 +3,7 @@ DPO strategies for zephyr
|
||||
"""
|
||||
|
||||
|
||||
def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
data = {}
|
||||
data["prompt"] = (
|
||||
|
||||
54
src/axolotl/prompt_strategies/input_output.py
Normal file
54
src/axolotl/prompt_strategies/input_output.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Module for plain input/output prompt pairs"""
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
|
||||
|
||||
class RawInputOutputStrategy(PromptTokenizingStrategy):
|
||||
"""Prompt Strategy class for input/output pairs"""
|
||||
|
||||
def __init__(self, *args, eos_token=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.eos_token = eos_token
|
||||
if not eos_token:
|
||||
self.eos_token = self.tokenizer.eos_token
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
input_ids = []
|
||||
labels = []
|
||||
for label, text in self.prompter.build_prompt(prompt["segments"]):
|
||||
tokenized_output = self.tokenizer(
|
||||
text, add_special_tokens=False, return_tensors=None
|
||||
)["input_ids"]
|
||||
input_ids += tokenized_output
|
||||
if label or self.train_on_inputs:
|
||||
labels += tokenized_output
|
||||
else:
|
||||
labels += [IGNORE_TOKEN_ID] * len(tokenized_output)
|
||||
|
||||
tokenized_prompt = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(input_ids),
|
||||
}
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class RawInputOutputPrompter(Prompter):
|
||||
"""prompter for raw i/o data"""
|
||||
|
||||
def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
|
||||
for segment in source:
|
||||
yield segment["label"], segment["text"]
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return RawInputOutputStrategy(
|
||||
RawInputOutputPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
@@ -1,10 +1,15 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from axolotl.utils.tokenization import (
|
||||
chatml_to_conversation,
|
||||
merge_consecutive_messages,
|
||||
)
|
||||
|
||||
|
||||
def register_chatml_template(system_message=None):
|
||||
@@ -19,6 +24,16 @@ def register_chatml_template(system_message=None):
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml_glaive",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
@@ -77,12 +92,26 @@ def load_guanaco(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
else "chatml_glaive"
|
||||
)
|
||||
return GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(conversation=conversation),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
_strict = True
|
||||
_strict = False
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
@@ -96,10 +125,25 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
conversations = prompt["conversations"]
|
||||
if self.strict:
|
||||
return conversations
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
|
||||
role_key = "from"
|
||||
if "role" in conversations[0].keys():
|
||||
role_key = "role"
|
||||
value_key = "value"
|
||||
if "text" in conversations[0].keys():
|
||||
value_key = "text"
|
||||
elif "content" in conversations[0].keys():
|
||||
value_key = "content"
|
||||
# remap roles - allow for assistant turn"
|
||||
role_map = {
|
||||
"user": "human",
|
||||
"human": "human",
|
||||
"assistant": "gpt",
|
||||
"gpt": "gpt",
|
||||
"system": "system",
|
||||
}
|
||||
turns = [
|
||||
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
|
||||
{"from": role_map[t[role_key]], "value": t[value_key]}
|
||||
for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
@@ -143,3 +187,15 @@ class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingSt
|
||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps glaive data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversation = chatml_to_conversation(prompt)
|
||||
conversation = merge_consecutive_messages(conversation)
|
||||
|
||||
return conversation
|
||||
|
||||
@@ -360,11 +360,19 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
user, assistant = conversation.roles
|
||||
tool_role_label = None
|
||||
if len(conversation.roles) == 3:
|
||||
(
|
||||
user_role_label,
|
||||
assistant_role_label,
|
||||
tool_role_label,
|
||||
) = conversation.roles
|
||||
else:
|
||||
user_role_label, assistant_role_label = conversation.roles
|
||||
role, content = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
if user in role:
|
||||
if user_role_label in role:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
@@ -384,7 +392,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif assistant in role:
|
||||
elif assistant_role_label in role:
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
@@ -426,6 +434,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif tool_role_label and tool_role_label in role:
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
@@ -267,6 +267,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
# Optional, only used for tool usage datasets.
|
||||
role_key_tool = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -274,6 +276,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
@@ -286,6 +289,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
if role_key_tool:
|
||||
self.role_key_tool = role_key_tool
|
||||
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
@@ -303,6 +308,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
source.pop(0)
|
||||
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
if self.role_key_tool:
|
||||
roles[self.role_key_tool] = conv.roles[2]
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
|
||||
@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
@@ -99,7 +99,7 @@ def train(
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg,
|
||||
|
||||
@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
|
||||
or not torch.cuda.is_available()
|
||||
or device == "auto"
|
||||
or torch.device(device).type == "cpu"
|
||||
or torch.device(device).type == "meta"
|
||||
):
|
||||
return default_value
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -9,7 +9,6 @@ from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import evaluate
|
||||
import mlflow
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -42,8 +41,8 @@ from axolotl.utils.distributed import (
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
IGNORE_INDEX = -100
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class EvalFirstStepCallback(
|
||||
@@ -756,31 +755,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to mlflow"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the MLflow artifacts."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
||||
return control
|
||||
44
src/axolotl/utils/callbacks/mlflow_.py
Normal file
44
src/axolotl/utils/callbacks/mlflow_.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""MLFlow module for trainer callbacks"""
|
||||
import logging
|
||||
from shutil import copyfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import mlflow
|
||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
||||
# pylint: disable=duplicate-code
|
||||
"""Callback to save axolotl config to mlflow"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
mlflow.log_artifact(temp_file.name, artifact_path="")
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the MLflow artifacts."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
||||
return control
|
||||
@@ -22,6 +22,7 @@ def chat_templates(user_choice: str):
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||
}
|
||||
|
||||
if user_choice in templates:
|
||||
|
||||
@@ -3,11 +3,16 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlConfigWCapabilities,
|
||||
AxolotlInputConfig,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model_config
|
||||
|
||||
@@ -119,7 +124,7 @@ def normalize_config(cfg):
|
||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||
or cfg.is_llama_derived_model
|
||||
or "llama" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
# figure out if the model is falcon
|
||||
@@ -135,7 +140,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_falcon_derived_model
|
||||
or "falcon" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_mistral_derived_model = (
|
||||
@@ -148,7 +153,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_mistral_derived_model
|
||||
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_qwen_derived_model = (
|
||||
@@ -159,9 +164,6 @@ def normalize_config(cfg):
|
||||
]
|
||||
) or cfg.is_qwen_derived_model
|
||||
|
||||
if isinstance(cfg.learning_rate, str):
|
||||
cfg.learning_rate = float(cfg.learning_rate)
|
||||
|
||||
if isinstance(cfg.pretraining_dataset, dict):
|
||||
cfg.pretraining_dataset = [cfg.pretraining_dataset]
|
||||
|
||||
@@ -191,7 +193,21 @@ def normalize_cfg_datasets(cfg):
|
||||
cfg.datasets[idx].conversation = "chatml"
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
if capabilities:
|
||||
return DictDefault(
|
||||
dict(
|
||||
AxolotlConfigWCapabilities(
|
||||
**cfg.to_dict(), capabilities=capabilities
|
||||
).model_dump(exclude_unset=True)
|
||||
)
|
||||
)
|
||||
return DictDefault(
|
||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))
|
||||
)
|
||||
|
||||
|
||||
def legacy_validate_config(cfg):
|
||||
"""
|
||||
This is a "pre-validation" step that handles the yaml configuration before we have any
|
||||
information about the model architecture
|
||||
@@ -363,11 +379,11 @@ def validate_config(cfg):
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
@@ -480,9 +496,6 @@ def validate_config(cfg):
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.warmup_steps and cfg.warmup_ratio:
|
||||
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
0
src/axolotl/utils/config/models/__init__.py
Normal file
0
src/axolotl/utils/config/models/__init__.py
Normal file
0
src/axolotl/utils/config/models/input/__init__.py
Normal file
0
src/axolotl/utils/config/models/input/__init__.py
Normal file
1004
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Normal file
1004
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
14
src/axolotl/utils/config/models/internals/__init__.py
Normal file
14
src/axolotl/utils/config/models/internals/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""module for gpu capabilities"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GPUCapabilities(BaseModel):
|
||||
"""model to manage the gpu capabilities statically"""
|
||||
|
||||
bf16: bool = Field(default=False)
|
||||
fp8: bool = Field(default=False)
|
||||
n_gpu: int = Field(default=1)
|
||||
n_node: int = Field(default=1)
|
||||
compute_capability: Optional[str] = Field(default=None)
|
||||
@@ -937,7 +937,9 @@ def load_prepare_dpo_datasets(cfg):
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
if _type:
|
||||
ds_transform_fn = load_dpo(_type, _cfg)
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
split_datasets[i] = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
|
||||
@@ -12,4 +12,4 @@ class DictDefault(Dict):
|
||||
return None
|
||||
|
||||
def __or__(self, other):
|
||||
return DictDefault(super().__or__(other))
|
||||
return DictDefault(super().__ror__(other))
|
||||
|
||||
@@ -3,13 +3,14 @@ module to freeze/unfreeze parameters by name
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||
|
||||
|
||||
def freeze_parameters_except(model, regex_patterns):
|
||||
def freeze_layers_except(model, regex_patterns):
|
||||
"""
|
||||
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||
@@ -17,22 +18,209 @@ def freeze_parameters_except(model, regex_patterns):
|
||||
Parameters:
|
||||
- model (nn.Module): The PyTorch model to be modified.
|
||||
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||
Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
|
||||
Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
|
||||
The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
|
||||
E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
|
||||
|
||||
Returns:
|
||||
None; the model is modified in place.
|
||||
"""
|
||||
# Escape periods and compile the regex patterns
|
||||
compiled_patterns = [
|
||||
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
||||
]
|
||||
if isinstance(regex_patterns, str):
|
||||
regex_patterns = [regex_patterns]
|
||||
|
||||
# First, freeze all parameters in the model
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
|
||||
|
||||
# Unfreeze layers that match the regex patterns
|
||||
for name, param in model.named_parameters():
|
||||
if any(pattern.match(name) for pattern in compiled_patterns):
|
||||
if is_main_process():
|
||||
LOG.debug(f"unfreezing {name}")
|
||||
param.requires_grad = False
|
||||
unfrozen_ranges = []
|
||||
for pattern in patterns:
|
||||
if not pattern.match(name):
|
||||
continue
|
||||
|
||||
param.requires_grad = True
|
||||
|
||||
if pattern.range is not None:
|
||||
unfrozen_ranges.append(pattern.range)
|
||||
|
||||
merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
|
||||
|
||||
if param.requires_grad and is_main_process():
|
||||
unfrozen_ranges = (
|
||||
f" with ranges {merged_unfrozen_ranges}"
|
||||
if merged_unfrozen_ranges
|
||||
else ""
|
||||
)
|
||||
LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
|
||||
|
||||
if not merged_unfrozen_ranges:
|
||||
continue
|
||||
|
||||
# The range list we need is actually the inverted of the merged ranges
|
||||
ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
|
||||
|
||||
param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
|
||||
|
||||
if is_main_process() and all(
|
||||
not param.requires_grad for param in model.parameters()
|
||||
):
|
||||
LOG.warning("All parameters are frozen. Model will not be trained.")
|
||||
|
||||
|
||||
def _invert_ranges(
|
||||
given_ranges: List[Tuple[int, int]], layer_size: int
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
|
||||
|
||||
Parameters:
|
||||
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
||||
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
||||
Returns:
|
||||
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
||||
"""
|
||||
if not given_ranges:
|
||||
return [(0, layer_size)]
|
||||
|
||||
inverted_ranges = []
|
||||
current_start = 0
|
||||
|
||||
for start, end in sorted(given_ranges):
|
||||
if start > current_start:
|
||||
inverted_ranges.append((current_start, start))
|
||||
current_start = max(current_start, end)
|
||||
|
||||
# Handle the case where the last given range does not reach the end of the total_size
|
||||
if current_start < layer_size:
|
||||
inverted_ranges.append((current_start, layer_size))
|
||||
|
||||
return inverted_ranges
|
||||
|
||||
|
||||
def _merge_ranges(
|
||||
given_ranges: List[Tuple[int, int | None]], layer_size: int
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Merges overlapping ranges and sorts the given ranges.
|
||||
|
||||
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
|
||||
as tuples, where the first element is the start index (inclusive) and the second element is the end
|
||||
index (exclusive). The end index can be None, indicating that the range extends to the end of the
|
||||
sequence.
|
||||
|
||||
Parameters:
|
||||
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
|
||||
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
||||
|
||||
Returns:
|
||||
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
|
||||
"""
|
||||
# End of each range can be determined now since we have the total size
|
||||
processed_ranges = [
|
||||
(start, end if end is not None else layer_size) for start, end in given_ranges
|
||||
]
|
||||
|
||||
# No need to merge if there's only one or no ranges
|
||||
if len(processed_ranges) <= 1:
|
||||
return processed_ranges
|
||||
|
||||
sorted_ranges = sorted(processed_ranges)
|
||||
|
||||
merged_ranges = [sorted_ranges[0]]
|
||||
for start, end in sorted_ranges[1:]:
|
||||
prev_start, prev_end = merged_ranges[-1]
|
||||
if start <= prev_end:
|
||||
merged_ranges[-1] = (prev_start, max(prev_end, end))
|
||||
else:
|
||||
merged_ranges.append((start, end))
|
||||
|
||||
return merged_ranges
|
||||
|
||||
|
||||
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
|
||||
"""
|
||||
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
|
||||
|
||||
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
|
||||
two integers representing the start and end indices of the range.
|
||||
|
||||
Parameters:
|
||||
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
|
||||
|
||||
Returns:
|
||||
- Callable: A hook function to be used with `register_hook` on parameters.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
ranges_to_freeze = [(0, 10), (20, 30)]
|
||||
hook = _create_freeze_parameters_hook(ranges_to_freeze)
|
||||
model.register_hook(hook)
|
||||
```
|
||||
"""
|
||||
|
||||
def freeze_parameters_hook(gradients):
|
||||
for start, end in ranges_to_freeze:
|
||||
gradients[start:end].zero_()
|
||||
|
||||
return freeze_parameters_hook
|
||||
|
||||
|
||||
class LayerNamePattern:
|
||||
"""
|
||||
Represents a regex pattern for layer names, potentially including a parameter index range.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern: str):
|
||||
"""
|
||||
Initializes a new instance of the LayerNamePattern class.
|
||||
|
||||
Parameters:
|
||||
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
|
||||
"""
|
||||
self.raw_pattern = pattern
|
||||
name_pattern, self.range = self._parse_pattern(pattern)
|
||||
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
|
||||
|
||||
def match(self, name: str) -> bool:
|
||||
"""
|
||||
Checks if the given layer name matches the regex pattern.
|
||||
|
||||
Parameters:
|
||||
- name (str): The layer name to check.
|
||||
|
||||
Returns:
|
||||
- bool: True if the layer name matches the pattern, False otherwise.
|
||||
"""
|
||||
return self.name_regex.match(name) is not None
|
||||
|
||||
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
|
||||
"""
|
||||
Extracts the range pattern from the given pattern.
|
||||
|
||||
Parameters:
|
||||
- pattern (str): The pattern to extract the range from.
|
||||
|
||||
Returns:
|
||||
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
|
||||
"""
|
||||
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
|
||||
if not match:
|
||||
return pattern, None
|
||||
|
||||
base_pattern, start_part, end_part = match.groups()
|
||||
|
||||
if end_part is None and start_part.isdecimal():
|
||||
index = int(start_part)
|
||||
return base_pattern, (index, index + 1)
|
||||
|
||||
# [:end] or [start:] or [start:end]
|
||||
start = int(start_part) if start_part else 0
|
||||
end = int(end_part) if end_part else None
|
||||
|
||||
if end is not None and start >= end:
|
||||
raise ValueError(
|
||||
f"Invalid range in layer name pattern: {pattern}."
|
||||
"End of range must be greater than start."
|
||||
)
|
||||
return base_pattern, (start, end)
|
||||
|
||||
@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
def setup_mlflow_env_vars(cfg: DictDefault):
|
||||
for key in cfg.keys():
|
||||
if key.startswith("mlflow_"):
|
||||
if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value and isinstance(value, str) and len(value) > 0:
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
"""Module for models and model loading"""
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401
|
||||
|
||||
import addict
|
||||
import bitsandbytes as bnb
|
||||
import safetensors
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from bitsandbytes.nn import Linear4bit, Params4bit
|
||||
from fastcore.parallel import parallel
|
||||
from peft import (
|
||||
LoftQConfig,
|
||||
PeftConfig,
|
||||
@@ -16,6 +23,7 @@ from peft import (
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from peft.tuners.lora import QuantLinear
|
||||
from torch import Tensor, nn
|
||||
from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
@@ -27,7 +35,9 @@ from transformers import ( # noqa: F401
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
||||
|
||||
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
@@ -86,8 +96,8 @@ def load_model_config(cfg):
|
||||
model_config_name = cfg.tokenizer_config
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
config_kwargs = {}
|
||||
if cfg.model_revision:
|
||||
config_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.revision_of_model:
|
||||
config_kwargs["revision"] = cfg.revision_of_model
|
||||
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
@@ -104,8 +114,8 @@ def load_model_config(cfg):
|
||||
)
|
||||
raise err
|
||||
|
||||
if cfg.model_config:
|
||||
for key, val in cfg.model_config.items():
|
||||
if cfg.overrides_of_model_config:
|
||||
for key, val in cfg.overrides_of_model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
|
||||
check_model_config(cfg, model_config)
|
||||
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def replace_linear(
|
||||
model: nn.Module,
|
||||
linear_replacement: Type[nn.Module],
|
||||
quant_config: Union[dict, None] = None,
|
||||
skip_modules=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Replace linear modules with a new Linear module.
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
linear_replacement (`torch.nn.Module`):
|
||||
The linear module that replaces the old one. Only expects standard arguments.
|
||||
If other arguments need to be passed, use a lambda.
|
||||
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
||||
List of modules names not to convert. Defaults to `lm_head`.
|
||||
"""
|
||||
if skip_modules is None:
|
||||
skip_modules = ["lm_head"]
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
replace_linear(
|
||||
module, linear_replacement, quant_config, skip_modules, **kwargs
|
||||
)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||
if issubclass(linear_replacement, Linear4bit):
|
||||
model._modules[ # pylint: disable=protected-access
|
||||
name
|
||||
] = linear_replacement(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported linear replacement: {type(linear_replacement)}"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def load_and_quantize(
|
||||
module: nn.Module,
|
||||
name: str,
|
||||
value: Tensor,
|
||||
device: torch.device = None,
|
||||
dtype: torch.dtype = None,
|
||||
skip_names: Optional[List[str]] = None,
|
||||
is_meta_rank: bool = False,
|
||||
low_memory: bool = True,
|
||||
verbose: bool = False,
|
||||
quant_method: str = "bnb",
|
||||
):
|
||||
"""
|
||||
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
||||
|
||||
Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
|
||||
"""
|
||||
|
||||
if skip_names is None:
|
||||
skip_names = []
|
||||
|
||||
def place_on_device(value):
|
||||
if is_meta_rank:
|
||||
device = "meta"
|
||||
elif low_memory:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cuda"
|
||||
return value.to(device=device, dtype=dtype)
|
||||
|
||||
if any(skip_name in name for skip_name in skip_names):
|
||||
if verbose:
|
||||
print(f"Skipping {name} because it is in skip_names")
|
||||
return
|
||||
|
||||
module_key, _, value_key = name.rpartition(".")
|
||||
try:
|
||||
submodule = module.get_submodule(module_key)
|
||||
except AttributeError as exc:
|
||||
print(f"Module {module_key} not found:\n{exc}")
|
||||
return
|
||||
|
||||
try:
|
||||
if quant_method == "bnb":
|
||||
param = submodule.get_parameter(value_key)
|
||||
if isinstance(param, Params4bit):
|
||||
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
|
||||
# shape as the quantized Params4bit with an initialized quant_state. However,
|
||||
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
||||
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
||||
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
||||
value = type(param)(
|
||||
value.to(device=device, dtype=dtype).data, **param.__dict__
|
||||
).cuda(device)
|
||||
if is_meta_rank:
|
||||
value = type(param)(value.data.to("meta"), **value.__dict__)
|
||||
elif low_memory:
|
||||
value = type(param)(value.data.to("cpu"), **value.__dict__)
|
||||
else:
|
||||
value = type(param)(place_on_device(value).data)
|
||||
|
||||
except AttributeError:
|
||||
# it's a buffer
|
||||
value = place_on_device(value)
|
||||
|
||||
setattr(submodule, value_key, value)
|
||||
|
||||
|
||||
def load_model(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@@ -272,7 +393,7 @@ def load_model(
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
model_type = cfg.model_type
|
||||
model_type = cfg.type_of_model
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
@@ -308,7 +429,7 @@ def load_model(
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
patch_for_multipack(cfg.model_config_type)
|
||||
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
||||
elif cfg.is_llama_derived_model:
|
||||
# Modify all llama derived models in one block
|
||||
|
||||
@@ -394,7 +515,7 @@ def load_model(
|
||||
|
||||
if max_memory is not None:
|
||||
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from accelerate import infer_auto_device_map
|
||||
|
||||
with init_empty_weights():
|
||||
model_canvas = AutoModelForCausalLM.from_config(model_config)
|
||||
@@ -426,8 +547,8 @@ def load_model(
|
||||
if is_deepspeed_zero3_enabled():
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
if cfg.model_revision:
|
||||
model_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.revision_of_model:
|
||||
model_kwargs["revision"] = cfg.revision_of_model
|
||||
if cfg.gptq:
|
||||
if not hasattr(model_config, "quantization_config"):
|
||||
LOG.warning("model config does not contain quantization_config information")
|
||||
@@ -496,8 +617,78 @@ def load_model(
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
||||
|
||||
qlora_fsdp = (
|
||||
cfg.fsdp
|
||||
and cfg.adapter == "qlora"
|
||||
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
|
||||
)
|
||||
|
||||
try:
|
||||
if (
|
||||
if qlora_fsdp:
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
|
||||
elif cfg.fp16 or cfg.float16:
|
||||
torch_dtype, compute_dtype = torch.float32, torch.float16
|
||||
else:
|
||||
torch_dtype, compute_dtype = torch.float32, torch.float16
|
||||
|
||||
with init_empty_weights():
|
||||
LOG.info("Loading model with empty weights.")
|
||||
model = AutoModelForCausalLM.from_config(model_config)
|
||||
model.model = replace_linear(
|
||||
model.model,
|
||||
Linear4bit,
|
||||
compute_dtype=compute_dtype,
|
||||
quant_type="nf4",
|
||||
quant_storage=torch_dtype,
|
||||
)
|
||||
|
||||
model.is_loaded_in_4bit = True
|
||||
|
||||
# Grab the safetensors files that hold the weights
|
||||
try:
|
||||
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
|
||||
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
|
||||
except OSError:
|
||||
try:
|
||||
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
|
||||
files = []
|
||||
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
|
||||
except OSError as exc:
|
||||
# This means the model probably doesn't have a safetensors file
|
||||
raise exc
|
||||
|
||||
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
|
||||
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
|
||||
def load_and_quantize_parallel(name_param, model, **kwargs):
|
||||
name, param = name_param
|
||||
load_and_quantize(model, name, param, **kwargs)
|
||||
|
||||
param_count = sum((p.numel() for n, p in model.named_parameters()))
|
||||
for filename in files:
|
||||
weights = safetensors.torch.load_file(filename)
|
||||
quant_method = "bnb"
|
||||
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
left = int(os.cpu_count() / torch.cuda.device_count())
|
||||
right = int(
|
||||
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
|
||||
)
|
||||
n_workers = min(left, right)
|
||||
parallel(
|
||||
load_and_quantize_parallel,
|
||||
weights.items(),
|
||||
n_workers=n_workers,
|
||||
threadpool=True,
|
||||
model=model,
|
||||
dtype=torch_dtype,
|
||||
device=cfg.local_rank,
|
||||
skip_names=[],
|
||||
is_meta_rank=(cfg.local_rank != 0),
|
||||
verbose=False,
|
||||
quant_method=quant_method,
|
||||
)
|
||||
|
||||
elif (
|
||||
model_config.model_type == "llama"
|
||||
and not cfg.trust_remote_code
|
||||
and not cfg.gptq
|
||||
@@ -512,43 +703,39 @@ def load_model(
|
||||
|
||||
if cfg.flash_attention and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
is_xformers_swiglu_available,
|
||||
replace_llama_mlp_with_swiglu,
|
||||
replace_llama_qkv_with_fused,
|
||||
)
|
||||
|
||||
if cfg.flash_attn_fuse_mlp:
|
||||
if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||
LOG.info("patching with SwiGLU")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
if cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("patching with fused QKV")
|
||||
replace_llama_qkv_with_fused(model)
|
||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||
# This is a WIP, still an issue with the backward pass
|
||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||
# TODO: try config.sequence_parallel = False
|
||||
# # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
|
||||
# # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
|
||||
# # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
|
||||
# from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
# from flash_attn.models.gpt import GPTLMHeadModel
|
||||
# from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
|
||||
# from transformers import GPTNeoXConfig
|
||||
# config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
|
||||
# config.use_flash_attn = True
|
||||
# config.fused_bias_fc = True
|
||||
# config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
|
||||
# config.activation_function = "gelu_fast"
|
||||
# config.fused_dropout_add_ln = True
|
||||
# # config.residual_in_fp32 = True
|
||||
#
|
||||
# model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
|
||||
# base_model,
|
||||
# config,
|
||||
# dtype=torch_dtype,
|
||||
# device=cfg.device,
|
||||
# )
|
||||
# model.train() # sets to train instead of eval mode
|
||||
elif (
|
||||
model_config.model_type == "mixtral"
|
||||
and not cfg.adapter
|
||||
and cfg.fuse_moe
|
||||
):
|
||||
from axolotl.monkeypatch.utils import set_module_name
|
||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, MixtralSparseMoeBlock):
|
||||
smoe = SparseMoeBlock(
|
||||
experts=module.experts,
|
||||
gate=module.gate,
|
||||
hidden_dim=module.hidden_dim,
|
||||
ffn_dim=module.ffn_dim,
|
||||
num_experts=module.num_experts,
|
||||
top_k=module.top_k,
|
||||
)
|
||||
set_module_name(model, name, smoe)
|
||||
|
||||
elif model_type == "MambaLMHeadModel":
|
||||
# FIXME this is janky at best and hacked together to make it work
|
||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||
@@ -612,7 +799,7 @@ def load_model(
|
||||
LOG.exception(err)
|
||||
raise err
|
||||
|
||||
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||
if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
|
||||
model = model.merge_and_unload()
|
||||
|
||||
embeddings_len = (
|
||||
@@ -691,6 +878,9 @@ def load_model(
|
||||
if cfg.adapter == "lora" and loftq_bits:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if qlora_fsdp:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
@@ -705,7 +895,7 @@ def load_model(
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if needs_fa2_dtype or cfg.flash_attention:
|
||||
if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
|
||||
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
@@ -723,7 +913,12 @@ def load_model(
|
||||
else:
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
|
||||
if (
|
||||
cfg.ddp
|
||||
and not load_in_8bit
|
||||
and not (cfg.rl and cfg.load_in_4bit)
|
||||
and not qlora_fsdp
|
||||
):
|
||||
# TODO revaldate this conditional
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
|
||||
@@ -812,6 +1007,30 @@ def find_all_linear_names(model):
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
def setup_quantized_meta_for_peft(model: nn.Module):
|
||||
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
|
||||
|
||||
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||
return self
|
||||
|
||||
for param in model.parameters():
|
||||
if isinstance(param, Params4bit):
|
||||
param.quant_state._orig_to = ( # pylint: disable=protected-access
|
||||
param.quant_state.to
|
||||
)
|
||||
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
||||
|
||||
|
||||
def setup_quantized_peft_meta_for_training(model: nn.Module):
|
||||
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
|
||||
for param in model.parameters():
|
||||
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
|
||||
param.quant_state.to = (
|
||||
param.quant_state._orig_to # pylint: disable=protected-access
|
||||
)
|
||||
param.quant_state._orig_to = None # pylint: disable=protected-access
|
||||
|
||||
|
||||
def load_lora(model, cfg, inference=False, config_only=False):
|
||||
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
||||
|
||||
@@ -829,6 +1048,10 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
if loftq_bits:
|
||||
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||
if cfg.peft_use_dora:
|
||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||
if cfg.peft_use_rslora:
|
||||
lora_config_kwargs["use_rslora"] = cfg.use_rslora
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=cfg.lora_r,
|
||||
@@ -846,6 +1069,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
if config_only:
|
||||
return None, lora_config
|
||||
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
|
||||
setup_quantized_meta_for_peft(model)
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||
model_kwargs: Any = {}
|
||||
@@ -861,6 +1089,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
if rank == 0:
|
||||
model.print_trainable_parameters()
|
||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
return model, lora_config
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
@@ -36,3 +38,65 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
LOG.info("\n\n\n")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
|
||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||
"SYSTEM": "system",
|
||||
"USER": "human",
|
||||
"ASSISTANT": "gpt",
|
||||
"FUNCTION RESPONSE": "tool",
|
||||
}
|
||||
|
||||
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
|
||||
|
||||
|
||||
def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Converts a ChatML formatted row to a list of messages in ShareGPT format.
|
||||
Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
|
||||
"""
|
||||
|
||||
system_prompt = row.get("system")
|
||||
if system_prompt:
|
||||
system_prompt = system_prompt.removeprefix("SYSTEM: ")
|
||||
|
||||
chat_str = row["chat"]
|
||||
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
|
||||
|
||||
chat_msg_dicts = [
|
||||
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
|
||||
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
|
||||
]
|
||||
|
||||
if system_prompt:
|
||||
chat_msg_dicts = [
|
||||
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
|
||||
] + chat_msg_dicts
|
||||
|
||||
return chat_msg_dicts
|
||||
|
||||
|
||||
def merge_consecutive_messages(messages):
|
||||
"""
|
||||
Merge consecutive messages from the same sender into a single message.
|
||||
This can be useful with datasets that contain multiple consecutive tool calls.
|
||||
"""
|
||||
|
||||
merged_messages = []
|
||||
current_from = None
|
||||
current_message = ""
|
||||
|
||||
for msg in messages:
|
||||
if current_from == msg["from"]:
|
||||
current_message += msg["value"]
|
||||
else:
|
||||
if current_from is not None:
|
||||
merged_messages.append({"from": current_from, "value": current_message})
|
||||
current_from = msg["from"]
|
||||
current_message = msg["value"]
|
||||
|
||||
if current_from is not None:
|
||||
merged_messages.append({"from": current_from, "value": current_message})
|
||||
|
||||
return merged_messages
|
||||
|
||||
@@ -255,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader) // batch_size
|
||||
data_loader_len = len(data_loader) // cfg.batch_size
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
|
||||
@@ -57,9 +57,9 @@ class TestFusedLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 10,
|
||||
"save_steps": 5,
|
||||
"eval_steps": 5,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
|
||||
@@ -43,7 +43,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"val_set_size": 0.2,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
|
||||
@@ -7,6 +7,8 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli import load_rl_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="doesn't seem to work on modal")
|
||||
class TestDPOLlamaLora(unittest.TestCase):
|
||||
"""
|
||||
Test case for DPO Llama models using LoRA
|
||||
|
||||
@@ -7,6 +7,8 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="doesn't seem to work on modal")
|
||||
class TestPhi(unittest.TestCase):
|
||||
"""
|
||||
Test case for Phi2 models
|
||||
|
||||
60
tests/monkeypatch/test_moe.py
Normal file
60
tests/monkeypatch/test_moe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import pytest
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig
|
||||
|
||||
def test_fused_mixtral_moe():
|
||||
# NOTE: Requires torch 2.2.0
|
||||
# Set random seeds for reproducibility
|
||||
torch.set_default_dtype(torch.float16)
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Define the configuration for the MixtralSparseMoeBlock
|
||||
config = MixtralConfig(
|
||||
hidden_size=128,
|
||||
intermediate_size=512,
|
||||
num_local_experts=8,
|
||||
num_experts_per_tok=2,
|
||||
)
|
||||
|
||||
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
||||
mixtral_moe = MixtralSparseMoeBlock(config)
|
||||
sparse_moe = SparseMoeBlock(
|
||||
experts=mixtral_moe.experts,
|
||||
gate=mixtral_moe.gate,
|
||||
hidden_dim=config.hidden_size,
|
||||
ffn_dim=config.intermediate_size,
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok
|
||||
)
|
||||
|
||||
assert torch.cat([
|
||||
mixtral_moe.experts[0].w1.weight.data,
|
||||
mixtral_moe.experts[0].w3.weight.data], dim=0
|
||||
).equal(sparse_moe.experts.experts.weight[0])
|
||||
|
||||
# Generate random input data
|
||||
batch_size = 16
|
||||
sequence_length = 32
|
||||
input_data = torch.randn(batch_size, sequence_length, config.hidden_size)
|
||||
|
||||
# Run the forward pass with gradients for both models
|
||||
with torch.no_grad():
|
||||
mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
|
||||
sparse_output, sparse_router_logits = sparse_moe(input_data)
|
||||
|
||||
# Compute the difference between the outputs
|
||||
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
|
||||
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
|
||||
|
||||
# Define the tolerance for the difference
|
||||
tolerance = 0.05
|
||||
|
||||
# # Check if the difference is within the tolerance
|
||||
assert output_diff < 0.05, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||
assert router_diff == 0, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||
116
tests/prompt_strategies/test_raw_io.py
Normal file
116
tests/prompt_strategies/test_raw_io.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Test module for raw i/o data for prompts
|
||||
"""
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.input_output import (
|
||||
RawInputOutputPrompter,
|
||||
RawInputOutputStrategy,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="segments_dataset")
|
||||
def fixture_sharegpt_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": False,
|
||||
"text": "<s>hello ",
|
||||
},
|
||||
{
|
||||
"label": True,
|
||||
"text": "hi there.<eot>",
|
||||
},
|
||||
{
|
||||
"label": False,
|
||||
"text": "goodbye ",
|
||||
},
|
||||
{
|
||||
"label": True,
|
||||
"text": "farewell<eot>",
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),
|
||||
]
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestRawInputOutputPrompts:
|
||||
"""
|
||||
Test class for raw i/o prompter
|
||||
"""
|
||||
|
||||
def test_segment_prompts(self, segments_dataset, tokenizer):
|
||||
strategy = RawInputOutputStrategy(
|
||||
RawInputOutputPrompter(),
|
||||
tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, segments_dataset, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
|
||||
assert (
|
||||
tokenizer.decode(input_ids)
|
||||
== "<s> hello hi there.<eot> goodbye farewell<eot>"
|
||||
)
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
1, # <s>
|
||||
6312, # hell
|
||||
28709, # o
|
||||
28705, #
|
||||
12014, # hi
|
||||
736, # there
|
||||
28723, # .
|
||||
32000, # <eot>
|
||||
1179, # good
|
||||
17664, # bye
|
||||
28705, #
|
||||
19111, # fare
|
||||
5458, # well
|
||||
32000, # <eot>
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
-100, # <s>
|
||||
-100, # hell
|
||||
-100, # o
|
||||
-100, #
|
||||
12014, # hi
|
||||
736, # there
|
||||
28723, # .
|
||||
32000, # <eot>
|
||||
-100, # good
|
||||
-100, # bye
|
||||
-100, #
|
||||
19111, # fare
|
||||
5458, # well
|
||||
32000, # <eot>
|
||||
]
|
||||
# fmt: on
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test module for sharegpt integration w chatml
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
@@ -8,6 +9,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
SimpleShareGPTPromptTokenizingStrategy,
|
||||
register_chatml_template,
|
||||
)
|
||||
@@ -48,6 +50,18 @@ def fixture_sharegpt_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="glaive_dataset")
|
||||
def fixture_sharegpt_glaive_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"system": "SYSTEM: This is a system prompt",
|
||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
@@ -156,3 +170,29 @@ class TestSharegpt:
|
||||
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_chatml_glaive(self, glaive_dataset, tokenizer):
|
||||
strategy = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
True, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, glaive_dataset, process_count=1
|
||||
)
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
1, # bos
|
||||
32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
@@ -39,7 +39,9 @@ class DictDefaultTest(unittest.TestCase):
|
||||
), "DictDefault should support in operator for existing keys in list"
|
||||
|
||||
def test_dict_or_operator(self):
|
||||
cfg = DictDefault(
|
||||
cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
|
||||
|
||||
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"key_a": {"key_b": "value_a"},
|
||||
"key_c": "value_c",
|
||||
@@ -48,10 +50,6 @@ class DictDefaultTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
|
||||
)
|
||||
|
||||
assert (
|
||||
cfg.key_a.key_b == "value_b"
|
||||
), "DictDefault should support OR operator for existing nested keys"
|
||||
|
||||
285
tests/test_freeze.py
Normal file
285
tests/test_freeze.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
This module contains unit tests for the `freeze_layers_except` function.
|
||||
|
||||
The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
|
||||
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
|
||||
ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||
|
||||
|
||||
class TestFreezeLayersExcept(unittest.TestCase):
|
||||
"""
|
||||
A test case class for the `freeze_layers_except` function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.model = _TestModel()
|
||||
|
||||
def test_freeze_layers_with_dots_in_name(self):
|
||||
freeze_layers_except(self.model, ["features.layer"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_without_dots_in_name(self):
|
||||
freeze_layers_except(self.model, ["classifier"])
|
||||
self.assertFalse(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_regex_patterns(self):
|
||||
# The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'.
|
||||
freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_all_layers_frozen(self):
|
||||
freeze_layers_except(self.model, [])
|
||||
self.assertFalse(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be frozen.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_all_layers_unfrozen(self):
|
||||
freeze_layers_except(self.model, ["features.layer", "classifier"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be trainable.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_start_end(self):
|
||||
freeze_layers_except(self.model, ["features.layer[1:5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_single_index(self):
|
||||
freeze_layers_except(self.model, ["features.layer[5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_start_omitted(self):
|
||||
freeze_layers_except(self.model, ["features.layer[:5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_end_omitted(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_included(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_intersect(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_separate(self):
|
||||
freeze_layers_except(
|
||||
self.model,
|
||||
["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def _assert_gradient_output(self, expected):
|
||||
input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
|
||||
|
||||
self.model.features.layer.weight.grad = None # Reset gradients
|
||||
output = self.model.features.layer(input_tensor)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
expected_grads = torch.tensor(expected)
|
||||
torch.testing.assert_close(
|
||||
self.model.features.layer.weight.grad, expected_grads
|
||||
)
|
||||
|
||||
|
||||
class _SubLayerModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer = nn.Linear(10, 10)
|
||||
|
||||
|
||||
class _TestModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.features = _SubLayerModule()
|
||||
self.classifier = nn.Linear(10, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -25,20 +25,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
def test_lr_as_float(self):
|
||||
cfg = (
|
||||
self._get_base_cfg()
|
||||
| DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"learning_rate": "5e-5",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
assert cfg.learning_rate == 0.00005
|
||||
|
||||
def test_base_model_config_set_when_empty(self):
|
||||
cfg = self._get_base_cfg()
|
||||
del cfg.base_model_config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for testing prompt tokenizers."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
@@ -18,6 +19,7 @@ from axolotl.prompt_strategies.llama2_chat import (
|
||||
Llama2ChatPrompter,
|
||||
LLama2ChatTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
@@ -204,13 +206,13 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
# fmt: off
|
||||
# System message, multi-turn conversations
|
||||
mt_ids = tokenize(test_data['multi_turn_sys'])
|
||||
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||
assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||
|
||||
# System message, single-turn conversations
|
||||
st_ids = tokenize(test_data['single_turn_sys'])
|
||||
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
|
||||
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
|
||||
assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||
|
||||
# No system message, single-turn
|
||||
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
||||
@@ -266,6 +268,23 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
idx = res["input_ids"].index(20255) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_glaive_tool_label_ignore(self):
|
||||
conversation = {
|
||||
"system": "SYSTEM: This is a system prompt",
|
||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
idx = res["input_ids"].index(13566) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user