Compare commits

..

4 Commits

Author SHA1 Message Date
Wing Lian
39ad38a1fb update address and port for spaces 2024-02-08 17:55:44 -05:00
Mads Henrichsen
ddb60883f5 create config 2024-02-08 09:26:58 +01:00
Mads Henrichsen
a5724ef08d axolotl start training 2024-02-07 18:16:21 +01:00
Mads Henrichsen
190930b5df spaces ui 2024-02-07 15:52:30 +01:00
130 changed files with 957 additions and 6253 deletions

View File

@@ -59,7 +59,6 @@ body:
label: Config yaml
description: |
Please attach the config yaml!
render: yaml
- type: textarea
id: possible-solution

View File

@@ -7,11 +7,16 @@ jobs:
build-base:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: axolotl-gpu-runner
runs-on: self-hosted
strategy:
fail-fast: false
matrix:
include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"

View File

@@ -17,6 +17,6 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0

View File

@@ -9,16 +9,21 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121
cuda_version: 12.1.0
@@ -30,7 +35,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: axolotl-gpu-runner
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -51,17 +56,27 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
load: true
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: Push to Docker Hub
if: github.event_name != 'pull_request'
run: |
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
if [ -n "$latest_tag" ]; then
docker push "$latest_tag"
fi
build-axolotl-runpod:
needs: build-axolotl
@@ -70,6 +85,11 @@ jobs:
strategy:
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
@@ -86,7 +106,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: axolotl-gpu-runner
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -23,7 +23,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
@@ -33,7 +33,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
python_version: ["3.9", "3.10", "3.11"]
timeout-minutes: 10
steps:
@@ -58,8 +58,8 @@ jobs:
docker-e2e-tests:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
runs-on: [self-hosted, gpu, docker]
timeout-minutes: 30
needs: [pre-commit, pytest]
strategy:
@@ -69,32 +69,44 @@ jobs:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
num_gpus: 1
pytorch: 2.0.1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
num_gpus: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
python-version: "3.10"
- name: Install Modal
images: winglian/axolotl-tests
- name: Build Docker image
run: |
python -m pip install --upgrade pip
pip install modal jinja2
- name: Update env vars
# Set up build arguments
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
CUDA="${{ matrix.cuda }}"
PYTORCH_VERSION="${{ matrix.pytorch }}"
# Build the Docker image
docker build . \
--file ./docker/Dockerfile-tests \
--build-arg BASE_TAG=$BASE_TAG \
--build-arg CUDA=$CUDA \
--build-arg GITHUB_REF=$GITHUB_REF \
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
--tag ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} \
--no-cache
- name: Unit Tests w docker image
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: GPU Unit Tests w docker image
run: |
modal run cicd.tests
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
- name: GPU Unit Tests monkeypatched w docker image
run: |
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
- name: Prune image from docker
if: github.ref != 'refs/heads/main'
run: |
docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}

5
.gitignore vendored
View File

@@ -167,8 +167,3 @@ cython_debug/
# WandB
# wandb creates a folder to store logs for training runs
wandb
# Runs
lora-out/*
qlora-out/*
mlruns/*

View File

@@ -1,5 +1,5 @@
[mypy]
plugins = pydantic.mypy
exclude = venv
[mypy-alpaca_lora_4bit.*]
@@ -32,9 +32,6 @@ ignore_missing_imports = True
[mypy-bitsandbytes]
ignore_missing_imports = True
[mypy-requests]
ignore_missing_imports = True
[mypy-datasets]
ignore_missing_imports = True

View File

@@ -31,7 +31,6 @@ repos:
additional_dependencies:
[
'types-PyYAML',
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5

227
README.md
View File

@@ -22,11 +22,11 @@ Features:
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Environment](#environment)
- [Installation](#installation)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
- [LambdaLabs](#lambdalabs)
- [Windows](#windows)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Dataset](#dataset)
@@ -34,7 +34,7 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference-playground)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- Advanced Topics
@@ -87,17 +87,15 @@ Features:
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
**Requirements**: Python >=3.9 and Pytorch >=2.0.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
### For developers
```bash
@@ -105,18 +103,9 @@ git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
```
General case:
```
pip3 install -e '.[flash-attn,deepspeed]'
```
Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
```
pip3 install -e '.'
```
### Usage
```bash
# preprocess datasets - optional but recommended
@@ -132,20 +121,15 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
```
## Advanced Setup
## Installation
### Environment
#### Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
```
Or run on the current files for development:
@@ -164,7 +148,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAcc
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
```
It additionally:
@@ -179,7 +163,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
</details>
#### Conda/Pip venv
1. Install python >=**3.10**
1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/
@@ -198,14 +182,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU
##### LambdaLabs
#### LambdaLabs
<details>
<summary>Click to Expand</summary>
@@ -213,11 +192,11 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
1. Install python
```bash
sudo apt update
sudo apt install -y python3.10
sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
python -V # should be 3.9
```
@@ -255,18 +234,15 @@ Please use WSL or Docker!
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
@@ -276,32 +252,31 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended):
#### Pretraining
- `completion`: raw corpus
```json
{"text": "..."}
```
Note: Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```yaml
pretraining_dataset: # hf path only
```
#### Supervised finetuning
##### Instruction
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
- path: <your-path>
type: sharegpt
conversation: llama-2
```
- `completion`: raw corpus
```json
{"text": "..."}
```
<details>
@@ -379,37 +354,14 @@ pretraining_dataset: # hf path only
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
</details>
##### Template-Free
- `input_output`: template-free prompt construction
```json
{"segments": [{"label": true|false, "text": "..."}]}
```
This is a special format that allows you to construct prompts without using templates. This is for advanced users who want more freedom with prompt construction. See [these docs](docs/input_output.md) for more details.
##### Conversation
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
<details>
<summary>See other formats</summary>
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
@@ -425,8 +377,6 @@ This is a special format that allows you to construct prompts without using temp
</details>
Note: `type: sharegpt` opens a special config `conversation:` that enables conversions to many Conversation types. See dataset section under [all yaml options](#all-yaml-options).
#### How to add custom prompts
For a dataset that is preprocessed for instruction purposes:
@@ -448,16 +398,12 @@ datasets:
format: "[INST] {instruction} [/INST]"
no_input_format: "[INST] {instruction} [/INST]"
```
See full config options under [all yaml options](#all-yaml-options).
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
```yaml
- path: ...
```
### Config
@@ -471,18 +417,22 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- dataset
```yaml
datasets:
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca
sequence_len: 2048 # max token length for prompt
# huggingface repo with specific configuration/subset
# huggingface repo
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
# huggingface repo with specific configuration/subset
datasets:
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
@@ -490,42 +440,39 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
datasets:
- path: ...
type: sharegpt
conversation: chatml # default: vicuna_v1.1
conversation: chatml
# local
# local
datasets:
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
# dataset with splits, but no train split
dataset:
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
- loading
```yaml
load_in_4bit: true
load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
@@ -533,7 +480,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # 'qlora' or leave blank for full finetune
adapter: lora # qlora or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
@@ -542,9 +489,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- v_proj
```
<details id="all-yaml-options">
<details>
<summary>All yaml options (click to expand)</summary>
<summary>All yaml options (click me)</summary>
```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
@@ -556,8 +503,8 @@ base_model_ignore_patterns:
# You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub
revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
@@ -574,16 +521,15 @@ tokenizer_legacy:
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# (Internal use only)
# Used to identify which the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_qwen_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration
overrides_of_model_config:
model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
@@ -600,6 +546,8 @@ bnb_config_kwargs:
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
@@ -673,7 +621,7 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair'
# use RL training: dpo, ipo, kto_pair
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
@@ -693,7 +641,7 @@ dataset_processes: # defaults to os.cpu_count() if not set
# Only needed if cached dataset is taking too much storage
dataset_keep_in_memory:
# push checkpoints to hub
hub_model_id: # private repo path to push finetuned model
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
@@ -772,8 +720,6 @@ peft:
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
@@ -789,7 +735,6 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -823,8 +768,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -853,11 +797,14 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim
lr_div_factor: # Learning rate div factor
# For log_sweep optim
log_sweep_min_lr:
log_sweep_max_lr:
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
@@ -1029,9 +976,6 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml
```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
@@ -1080,10 +1024,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
@@ -1145,7 +1085,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
@@ -1206,7 +1146,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
@@ -1253,28 +1193,13 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Sponsors 🤝❤
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
@@ -1303,6 +1228,4 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
#### 🥉 Bronze Sponsors - $500/mo
- [JarvisLabs.ai](https://jarvislabs.ai)
---

View File

@@ -1,39 +0,0 @@
FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -1,5 +0,0 @@
#!/bin/bash
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

View File

@@ -1,75 +0,0 @@
"""
modal application to run axolotl gpu tests in Modal
"""
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
"CUDA": os.environ.get("CUDA", "118"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -16,7 +16,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -20,7 +20,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,7 +24,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,7 +24,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -2,6 +2,7 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -3,10 +3,9 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi
# So we can test the Docker image

View File

@@ -7,8 +7,8 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.0.1"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

View File

@@ -3,10 +3,9 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ARG PYTORCH_VERSION="2.0.1"
ARG GITHUB_REF="main"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -25,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi
# So we can test the Docker image

View File

@@ -74,6 +74,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
If you developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh). You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host).
```bash
### Configuration

View File

@@ -1,37 +0,0 @@
# FDSP + QLoRA
## Background
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
Below, we describe how to use this feature in Axolotl.
## Usage
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> ![Tip]
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
## References
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
- Related HuggingFace PRs Enabling FDSP + QLoRA:
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.

View File

@@ -1,260 +0,0 @@
# Template-free prompt construction with the `input_output` format
<!-- TOC -->
- [Background](#background)
- [Masking Inputs](#masking-inputs)
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
- [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
- [1. Prepare Data](#1-prepare-data)
- [2. Use `type: input_output`](#2-use-type-input_output)
- [3. Check the prompts](#3-check-the-prompts)
<!-- /TOC -->
<a id="markdown-background" name="background"></a>
## Background
<a id="markdown-masking-inputs" name="masking-inputs"></a>
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
### You may not want prompt templates
However, there are many situations where you don't want to use one of
these formats or templates (I usually don't!). This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
### The `input_output` format
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
<a id="markdown-usage" name="usage"></a>
## Usage
This is how you can use the `input_output` format:
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
### 1. Prepare Data
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output`.jsonl` pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
### 2. Use `type: input_output`
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
### 3. Check the prompts
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ingored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
| token | label | id |
|-------|-------|-------|
| 0 | \<s\> | 1 |
| 1 | Hello | 22557 |
| 2 | \\n | 13 |
| 3 | hi | 12014 |
| 4 | there | 736 |
| 5 | ! | 28808 |
| 6 | . | 28723 |
| 7 | | 28705 |
| 8 | good | -100 |
| 9 | bye | -100 |
| 10 | | -100 |
| 11 | fare | 19111 |
| 12 | well | 5458 |
| 13 | \</s\>| 2 |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```

View File

@@ -1,18 +0,0 @@
# Mac M series support
Currently Axolotl on Mac is partially usable, many of the dependencies of Axolotl including Pytorch do not support MPS or have incomplete support.
Current support:
- [x] Support for all models
- [x] Full training of models
- [x] LoRA training
- [x] Sample packing
- [ ] FP16 and BF16 (awaiting AMP support for MPS in Pytorch)
- [ ] Tri-dao's flash-attn (until it is supported use spd_attention as an alternative)
- [ ] xformers
- [ ] bitsandbytes (meaning no 4/8 bits loading and bnb optimizers)
- [ ] qlora
- [ ] DeepSpeed
Untested:
- FSDP

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -177,24 +177,6 @@
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false

View File

@@ -5,7 +5,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false

View File

@@ -1,65 +0,0 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false

View File

@@ -1,4 +1,5 @@
base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_disable_exllama: true
model_type: AutoModelForCausalLM

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -59,7 +60,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -56,7 +57,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,70 +0,0 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,6 +2,7 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -60,7 +61,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -48,7 +49,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,79 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./lora-out
eval_sample_packing: false
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
sdp_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,74 +0,0 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:

View File

@@ -16,12 +16,12 @@ output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - ^lm_head.weight$
# - ^model.embed_tokens.weight$[:32000]
# - model.layers.2[0-9]+.block_sparse_moe.gate
# - model.layers.2[0-9]+.block_sparse_moe.experts
# - model.layers.3[0-9]+.block_sparse_moe.gate
# - model.layers.3[0-9]+.block_sparse_moe.experts
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
model_config:
output_router_logits: true
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero2.json

View File

@@ -1,75 +0,0 @@
import gc
import torch
from tqdm import tqdm
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
from transformers import AutoTokenizer, TextStreamer
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
memory_pct = (
memory_used
/ (torch.cuda.get_device_properties(device).total_memory / (1024**3))
* 100
)
return memory_pct
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Load model
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
model = MixtralForCausalLM.from_pretrained(
model_path,
config=config,
device_map="auto",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
)
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct)
with tqdm(modules.items(), desc="scatter moe") as pbar:
for i, (name, module) in enumerate(pbar):
smoe = SparseMoeBlock(
experts=module.experts,
gate=module.gate,
hidden_dim=module.hidden_dim,
ffn_dim=module.ffn_dim,
num_experts=module.num_experts,
top_k=module.top_k,
)
old_module = model.model.layers[i].block_sparse_moe
setattr(model.model.layers[i], "block_sparse_moe", smoe)
del old_module
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
for device_index in range(torch.cuda.device_count()):
device_memory_pct = compute_memory_used_pct(device_index)
print(device_index, device_memory_pct)
tokenizer = AutoTokenizer.from_pretrained(model_path)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
@@ -67,7 +68,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: true
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: false
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,69 +0,0 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,66 +0,0 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,36 +0,0 @@
# StableLM 2
This repository contains examples for training and processing using StableLM-2. It also includes a section to help you estimate the GPU requirements for your specific use case.
## Estimating GPU Requirements
| type | deepspeed | batch size | context length | vRAM GPU (GBs) |
|---------------|-----------|------------|----------------|----------------|
| full finetune | N/A | 1 | 4096 | ~21.5GBs |
| full finetune | zero2 | 1 | 4096 | ~20GBs |
| lora | N/A | 1 | 4096 | ~16.6GBs |
The above are estimates and might differ slight depending on the setup for example whether you pack your sequence lengths or not (the above assumes you do to length 4096).
This blog post from Hamel Husain was a great resource for estimating these numbers: https://hamel.dev/notes/llm/03_estimating_vram.html
## Training
We have example scripts here for both full finetuning and lora using the popular alpaca dataset:
```shell
# preprocess the dataset
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/stablelm-2/1.6b/lora.yml
```
Single GPU Training:
```shell
python -m axolotl.cli.train examples/stablelm-2/fft.yml --deepspeed deepspeed_configs/zero2.json
# OR
python -m axolotl.cli.train examples/stablelm-2/1.6b/lora.yml
```
Multinode GPU Training with `accelerate`:
```shell
# make sure you've configured accelerate properly
accelerate launch -m axolotl.cli.train examples/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json
```

View File

@@ -1,69 +0,0 @@
base_model: bigcode/starcoder2-3b
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.2
output_dir: ./qlora
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_steps:
eval_table_size:
saves_per_epoch: 4
save_steps:
save_total_limit: 2
debug:
deepspeed:
weight_decay:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,64 +0,0 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -15,7 +16,6 @@ output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora

View File

@@ -2,6 +2,7 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,8 @@
base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -28,7 +29,7 @@ num_epochs: 1
val_set_size: 0.1
evals_per_epoch: 5
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
eval_sample_packing: false
eval_batch_size: 1

View File

@@ -1,4 +1,3 @@
pre-commit
black
mypy
types-requests

View File

@@ -1,18 +1,16 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.9.0
transformers==4.38.2
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
tokenizers==0.15.0
bitsandbytes>=0.43.0
bitsandbytes>=0.41.1
accelerate==0.26.1
deepspeed==0.13.1
pydantic==2.6.3
deepspeed>=0.13.1
addict
fire
PyYAML>=6.0
requests
datasets>=2.15.0
flash-attn==2.5.5
flash-attn==2.3.3
sentencepiece
wandb
einops
@@ -22,13 +20,14 @@ hf_transfer
colorama
numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.1
evaluate==0.4.0
scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.36
fschat==0.2.34
gradio==3.50.2
tensorboard
@@ -40,4 +39,3 @@ gcsfs
# adlfs
trl>=0.7.9
fastcore>=1.5.29

View File

@@ -1,7 +1,5 @@
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
@@ -18,7 +16,6 @@ def parse_requirements():
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
@@ -29,25 +26,11 @@ def parse_requirements():
_install_requires.append(line)
try:
if "Darwin" in platform.system():
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
if torch_version.startswith("2.1."):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
else:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 1):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
_install_requires.append("xformers>=0.0.23")
except PackageNotFoundError:
pass
@@ -68,13 +51,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.5.5",
"flash-attn==2.5.0",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed==0.13.1",
"deepspeed>=0.13.1",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -83,11 +66,5 @@ setup(
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
},
)

View File

@@ -1,19 +1,16 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import requests
import gradio as gr
import torch
import yaml
@@ -23,7 +20,6 @@ from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -63,52 +59,6 @@ def print_axolotl_text_art(suffix=None):
print(ascii_art)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
@@ -214,8 +164,6 @@ def do_inference_gradio(
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
@@ -322,14 +270,14 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
config = check_remote_config(config)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(Path(config))
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
@@ -342,22 +290,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"compute_capability": gpu_version,
},
)
validate_config(cfg)
prepare_optim_env(cfg)

View File

@@ -3,7 +3,6 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -24,7 +23,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,7 +3,6 @@ CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -26,7 +25,7 @@ def shard(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Tuple, Union
from typing import Tuple
import fire
from transformers.hf_argparser import HfArgumentParser
@@ -25,7 +25,7 @@ from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs))

View File

@@ -1,55 +0,0 @@
"""module for building the auto wrap policy for FSDP"""
import functools
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]
def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer
def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""
def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)
lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)
return get_wrapping_policy

View File

@@ -5,10 +5,8 @@ Builder for the training args and trainer
import abc
import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -18,10 +16,7 @@ from typing import List, Optional, Type, Union
import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -31,21 +26,17 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoMlflowCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import (
@@ -58,12 +49,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
@@ -72,10 +59,6 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
@@ -147,10 +130,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
@@ -163,9 +142,6 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
@@ -183,23 +159,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
class AxolotlTrainer(Trainer):
@@ -224,33 +183,6 @@ class AxolotlTrainer(Trainer):
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
@@ -284,16 +216,6 @@ class AxolotlTrainer(Trainer):
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
@@ -477,56 +399,6 @@ class AxolotlTrainer(Trainer):
return super().push_to_hub(*args, **kwargs)
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess()
if self.args.qlora is False:
return res
# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)
mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']
if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrapping_policy(),
cpu_offload=False,
use_orig_params=False,
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin
return res
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -750,11 +622,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
if self.cfg.use_mlflow:
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
@@ -774,11 +642,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
@@ -846,9 +709,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
@@ -930,8 +790,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
@@ -992,10 +850,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and not self.cfg.test_datasets
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
@@ -1016,10 +872,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
@@ -1030,9 +882,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
@@ -1050,20 +899,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
@@ -1075,42 +913,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
trainer_kwargs = {}
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
@@ -1180,7 +994,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
if use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]

View File

@@ -30,7 +30,6 @@ class ColorfulFormatter(Formatter):
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",

View File

@@ -1,133 +0,0 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for falcon
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -106,7 +106,7 @@ def get_turns( # pylint: disable=too-many-return-statements
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(

View File

@@ -44,18 +44,6 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_swiglu_available() -> bool:
from xformers.ops.common import get_xformers_operator
try:
get_xformers_operator("swiglu_packedw")()
return True
except RuntimeError as exc:
if "No such operator xformers::swiglu_packedw " in str(exc):
return False
return True
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
@@ -287,9 +275,7 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -439,9 +425,7 @@ def flashattn_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -704,9 +688,6 @@ def llama_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions

View File

@@ -2,6 +2,9 @@
Patches to support multipack for mixtral
"""
import torch
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def patch_mixtral_moe_forward_zero3() -> None:
@@ -48,3 +51,11 @@ def patch_mixtral_moe_forward_zero3() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if for_zero3:
patch_mixtral_moe_forward_zero3()

View File

@@ -1,149 +0,0 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import torch
import torch.nn as nn
from axolotl.monkeypatch.moe import ops
class ParallelLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx, x, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates=None, grouped_in=False, grouped_out=False,
):
output = ops.scatter2scatter(
X=x, W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
padded_block_idxs=padded_block_idxs,
k=k, x_grouped=grouped_in, y_grouped=grouped_out
)
if gates is not None:
output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
output = torch.bmm(
gates[:, None, :],
output_expanded
).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x, expert_weights,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates,
output_expanded
)
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
return output
@staticmethod
def backward(ctx, grad_out):
(x, expert_weights,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates, output_expanded) = ctx.saved_tensors
k = ctx.k
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
# print("backward")
if gates is not None:
# calculate gates gradient
d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
# print("expanded and grouping")
grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = ops.group(grad_out, sorted_scattered_idxs,
fan_out=gate_fan, coeff=gates_flat,
out=grouped_grad_out)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x
d_weights = ops.group_bwd_W(
DY=grouped_grad_out, X=grouped_x,
expert_offsets=expert_offsets,
E=expert_weights.size(0)
)
d_expanded_input = ops.scatter2scatter(
X=grouped_grad_out, x_grouped=True,
W=expert_weights.permute(0, 2, 1),
padded_block_idxs=padded_block_idxs,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input # Reuse grouped_x buffer
)
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
# print("backward end.")
return (
# x, expert_weights, k,
d_input, d_weights, None,
# sorted_expert_idxs, sorted_scattered_idxs,
None, None,
# padded_block_idxs, expert_offsets,
None, None,
# gates
d_gates, None, None
)
def parallel_linear(inputs, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates=None):
results = ParallelLinear.apply(inputs, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets, gates)
return results
class ParallelExperts(nn.Module):
def __init__(self, num_experts, input_size, output_size, device) -> None:
super().__init__()
self.weight = nn.Parameter(
torch.empty(num_experts, output_size, input_size, device=device)
)
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
def extra_repr(self):
return 'num_experts={}, input_size={}, output_size={}'.format(
self.num_experts, self.input_size, self.output_size)
def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates=None, grouped_in=False, grouped_out=False):
results = ParallelLinear.apply(
inputs, self.weight.permute(0, 2, 1), k,
sorted_expert_idxs, sorted_scattered_idxs,
padded_block_idxs, expert_offsets,
gates, grouped_in, grouped_out
)
return results

View File

@@ -1,86 +0,0 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import gc
import torch
from torch import nn
from axolotl.monkeypatch.moe import ops
from axolotl.monkeypatch.moe.linear import ParallelExperts
class FusedExperts(nn.Module):
def __init__(
self,
experts: nn.ModuleList =None,
hidden_dim=128,
ffn_dim=512,
num_experts=8,
top_k=2,
activation=nn.SiLU(),
):
"""
This implements fused experts that are compatible with Mixtral.
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
"""
super(FusedExperts, self).__init__()
device = experts[0].w1.weight.device
self.num_experts = num_experts
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
self.top_k = min(top_k, self.num_experts)
self.activation = activation
with torch.no_grad():
for i in range(len(experts)):
self.experts.weight.data[i].copy_(
torch.cat(
[experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
dim=0
)
)
self.output_experts.weight.data[i].copy_(
experts[i].w2.weight.detach()
)
def forward(
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
):
x_shape = x.size()
x = x.view(-1, x_shape[-1])
with torch.no_grad():
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
selected_experts
)
padded_block_idxs, expert_offsets = ops.padded_block_indices(
sorted_expert_idxs, self.num_experts
)
h, gates = self.experts(
x,
self.top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
grouped_out=True,
).chunk(2, dim=-1)
h = self.activation(gates) * h
y = self.output_experts(
h,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
grouped_in=True,
gates=routing_weights,
)
y = y.view(*x_shape[:-1], y.size(-1))
return y

View File

@@ -1,50 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from axolotl.monkeypatch.moe.mlp import FusedExperts
class SparseMoeBlock(nn.Module):
def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
super().__init__()
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
self.num_experts = num_experts
self.top_k = top_k
self.gate = gate
self.experts = FusedExperts(
experts=experts,
hidden_dim=hidden_dim,
ffn_dim=ffn_dim,
num_experts=num_experts,
top_k=top_k,
activation=experts[0].act_fn
)
def _post_training(self, model, name):
# get original weights back: reverse the concat + stack in the fused experts
w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
# TODO: recreate MoE class with original weights
experts = []
for i in range(self.num_experts):
pass
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# Fused expert forward
final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

View File

@@ -1,353 +0,0 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import torch
import triton
import triton.language as tl
from torch.nn import functional as F
BLOCK_M = 128
@torch.jit.script
def flatten_and_sort(expert_idxs:torch.Tensor):
flattened_expert_idxs = expert_idxs.flatten()
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
return sorted_expert_idxs, sorted_scattered_idxs
@torch.jit.script
def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int=BLOCK_M) :
expert_counts = torch.bincount(sorted_experts_idxs, minlength=k)
padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1
padded_expert_block_end = padded_block_counts.cumsum(-1)
expert_boundaries_end = expert_counts.cumsum(-1)
expert_boundaries_start = expert_boundaries_end - expert_counts
padded_expert_block_start = padded_expert_block_end - padded_block_counts
block_idxs = torch.arange(padded_expert_block_end[-1],
dtype=sorted_experts_idxs.dtype,
device=sorted_experts_idxs.device)
block_mask = (
(block_idxs[:, None] < padded_expert_block_start) |
(block_idxs[:, None] >= padded_expert_block_end)
)
expanded_block_idxs = (
N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) +
expert_boundaries_start
)
expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1)
return expanded_block_idxs, expert_boundaries_end
def _scatter2scatter_configs():
return [
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
@triton.heuristics({
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _scatter2scatter(
X_ptr, stride_xm, stride_xk,
W_ptr, stride_we, stride_wk, stride_wn,
Y_ptr, stride_ym, stride_yn,
grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr,
FAN_OUT: tl.constexpr,
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
OUT_M: tl.constexpr,
allow_tf32: tl.constexpr,
x_grouped: tl.constexpr, y_grouped: tl.constexpr,
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
pid = tl.program_id(axis=0)
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_range = tl.arange(0, BLOCK_M)
block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
# M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M)
M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
E_idx = tl.min(E_idxs)
E_mask = E_idxs == E_idx
M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = M_idx // FAN_OUT
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx
K_block = tl.arange(0, BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
# N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
# N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(K, BLOCK_K)
for K_block_id in range(0, iters):
if NO_K_MASK:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
if NO_N_MASK:
w = tl.load(W_blk_ptrs)
else:
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
else:
K_mask = (K_block_id * BLOCK_K + K_block) < K
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
padded_block_idxs, x_grouped=False, y_grouped=False,
out=None):
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
# Pre-kernel setup
x_dim = X.size(-1)
y_dim = W.size(-1)
L_scattered = sorted_expert_idxs.size(0)
if out is None:
O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
else:
assert out.size(0) == L_scattered and out.size(1) == y_dim
O = out
def grid(META):
grid_num = (
padded_block_idxs.size(0) *
triton.cdiv(META['N'], META['BLOCK_N']),
)
return grid_num
"""
print("X", X.size(), X.stride(),
"W", W.size(), W.stride(),
"O", O.size(), O.stride(),
"sorted_idxs", sorted_scattered_idxs.size(),
"FAN_OUT", k,
"BLOCK_M", BLOCK_M,
"grouped", (x_grouped, y_grouped))
"""
_scatter2scatter[grid](
# X_ptr, stride_xm, stride_xk,
X, X.stride(0), X.stride(1),
# W_ptr, stride_we, stride_wk, stride_wn,
W, W.stride(0), W.stride(1), W.stride(2),
# Y_ptr, stride_ym, stride_yn,
O, O.stride(0), O.stride(1),
grouped_idx_ptr=sorted_scattered_idxs,
expert_idxs_ptr=sorted_expert_idxs,
block_start_idx_ptr=padded_block_idxs,
FAN_OUT=k,
M=X.size(0),
K=X.size(1),
N=O.size(1), E=W.size(0),
BLOCK_M=BLOCK_M,
ACC_TYPE=tl.float32,
OUT_M=O.size(0),
allow_tf32=True,
x_grouped=x_grouped, y_grouped=y_grouped,
)
return O
def _config_XtY():
return [
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
]
def group_bwd_W(DY, X, expert_offsets, E):
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
DW = DWt.permute(0, 2, 1)
def grid(META):
grid = (
E * triton.cdiv(META['K'], META['BLOCK_K']),
triton.cdiv(META['N'], META['BLOCK_N']),
)
return grid
_groupXtY[grid](
# DY_ptr, stride_dym, stride_dyk,
DY, DY.stride(0), DY.stride(1),
# X_ptr, stride_xm, stride_xn,
X, X.stride(0), X.stride(1),
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
DW, DW.stride(0), DW.stride(1), DW.stride(2),
# expert_offsets_ptr,
expert_offsets,
# K: tl.constexpr, N: tl.constexpr,
M=DY.size(0), N=DY.size(-1), K=X.size(-1),
# ACC_TYPE: tl.constexpr,
ACC_TYPE=tl.float32,
allow_tf32=True
)
return DW
@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
@triton.heuristics({
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _groupXtY(
DY_ptr, stride_dym, stride_dyk,
X_ptr, stride_xm, stride_xn,
DW_ptr, stride_dwe, stride_dwk, stride_dwn,
expert_offsets_ptr,
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
K_block_id = pid0 % K_BLOCK_COUNT
N_block_id = pid1
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
if end_idx > start_idx:
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
for i in range(0, iters):
M_mask = (i * BLOCK_M + M_block) < end_idx
if NO_K_MASK:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
if NO_N_MASK:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
def _config_grouping():
return [
triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
N = sorted_expert_idxs.size(0)
K = A.size(1)
assert A.size(0) * fan_out == N
if out is not None:
Y = out
else:
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
# print("grp init:", Y.size())
def grid(META):
grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
return grid_num
_group[grid](
# A_ptr, stride_an, stride_ai,
A, A.stride(0), A.stride(1), coeff is not None, coeff, fan_out,
# Y_ptr, stride_yn, stride_yk,
Y, Y.stride(0), Y.stride(1),
# grouped_idx_ptr,
sorted_expert_idxs,
# N: tl.constexpr, K: tl.constexpr,
N, K
)
return Y
@triton.autotune(configs=_config_grouping(), key=['K'])
@triton.heuristics({
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
})
@triton.jit
def _group(
src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
tgt_ptr, stride_tn, stride_ti,
grouped_idx_ptr,
N: tl.constexpr, K: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
NO_K_MASK: tl.constexpr
):
pid = tl.program_id(axis=0)
N_block_id = pid
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_blk < N
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
iters = tl.cdiv(K, BLOCK_K)
for i in range(0, iters):
if NO_K_MASK:
block = tl.load(src_blk_ptrs) # , mask=N_mask[:, None])
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
else:
K_mask = (i * BLOCK_K + K_blk) < K
mask = N_mask[:, None] & K_mask[None, :]
block = tl.load(src_blk_ptrs, mask=mask)
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=mask)
src_blk_ptrs += BLOCK_K * stride_sk
tgt_blk_ptrs += BLOCK_K * stride_ti

View File

@@ -1,66 +0,0 @@
"""
Adapted from:
https://github.com/shawntan/scattermoe
https://arxiv.org/abs/2403.08245
"""
import torch
import triton
import triton.language as tl
from torch.nn import functional as F
@triton.jit
def _single2scatter(
X_ptr, stride_xm, stride_xk,
W_ptr, stride_we, stride_wk, stride_wn,
Y_ptr, stride_ym, stride_yn,
expert_idxs_ptr,
FAN_OUT: tl.constexpr,
K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
N_block_id = pid0
if FAN_OUT == 1:
in_idx = pid1
else:
in_idx = 0
out_idx = pid1
K_block = tl.arange(0, BLOCK_K)
N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
E_idx = tl.load(expert_idxs_ptr + pid1)
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
x = tl.load(X_blk_ptrs)
w = tl.load(W_blk_ptrs)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
tl.store(Y_blk_ptrs, acc)
def single2scatter(X, W, expert_idxs):
E, xdim, ydim = W.size()
k = expert_idxs.size(1)
assert X.size(0) == k or X.size(0) == 1
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = ydim // BLOCK_N, k
_single2scatter[grid](
X, X.stride(0), X.stride(1),
W, W.stride(0), W.stride(1), W.stride(2),
Y, Y.stride(0), Y.stride(1),
expert_idxs,
FAN_OUT=Y.size(0) // X.size(0),
K=xdim, N=ydim, E=E,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
ACC_TYPE=tl.float32
)
return Y

View File

@@ -1,61 +0,0 @@
"""multipack patching for v2 of sample packing"""
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"falcon",
"phi",
"gemma",
"gemmoe",
"starcoder2",
]
def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_gemmoe to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_gemmoe", ".modeling_gemmoe"
)
modeling_gemmoe = importlib.import_module(module_name)
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for phi2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_phi_attn_with_multipack_flash_attn():
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for qwen2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_qwen2_attn_with_multipack_flash_attn():
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -46,9 +46,8 @@ def reset_optimizer(
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
n_zeros = 0
n_total = 0
@@ -160,7 +159,6 @@ class ReLoRACallback(TrainerCallback):
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
)
if self.quantized:
@@ -267,7 +265,7 @@ class ReLoRAScheduler(LRScheduler):
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps - self.warmup_steps:
if step < self.relora_steps:
scale = 1
else:
per_relora_progress = step % self.relora_steps

View File

@@ -186,8 +186,8 @@ def mask_2d_to_4d(
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.

View File

@@ -1,78 +0,0 @@
"""
HF Chat Templates prompt strategy
"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates"""
def __init__(self, tokenizer, chat_template=None, max_length=2048):
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False):
return self.tokenizer.apply_chat_template(
conversation,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy

View File

@@ -8,13 +8,14 @@ import logging
LOG = logging.getLogger("axolotl")
def load(strategy, cfg, **kwargs):
def load(strategy, cfg):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
load_kwargs = {}
return func(cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None

View File

@@ -5,7 +5,6 @@ DPO strategies for chatml
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
@@ -24,28 +23,8 @@ def argilla(
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
return transform_fn
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
@@ -69,7 +48,7 @@ def icr(
return transform_fn
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
@@ -91,9 +70,7 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
return transform_fn
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
@@ -111,7 +88,7 @@ def prompt_pairs(
return transform_fn
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""

View File

@@ -1,41 +0,0 @@
"""
User-defined DPO strategies
"""
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
ds_cfg = cfg["datasets"][dataset_idx]["type"]
if not isinstance(ds_cfg, dict):
raise ValueError(
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
)
field_prompt = ds_cfg.get("field_prompt", "prompt")
field_system = ds_cfg.get("field_system", "system")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
prompt_format = ds_cfg.get("prompt_format")
if not prompt_format:
prompt_format = "{" + field_prompt + "}"
chosen_format = ds_cfg.get("chosen_format")
if not chosen_format:
chosen_format = "{" + field_chosen + "}"
rejected_format = ds_cfg.get("rejected_format")
if not rejected_format:
rejected_format = "{" + field_rejected + "}"
def transform_fn(sample):
if (
"{" + field_system + "}" in prompt_format
and field_system in sample
and sample[field_system]
):
sample["prompt"] = prompt_format.format(
system=sample[field_system], prompt=sample[field_prompt]
)
else:
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample
return transform_fn

View File

@@ -3,7 +3,7 @@ DPO strategies for zephyr
"""
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
data = {}
data["prompt"] = (

View File

@@ -1,54 +0,0 @@
"""Module for plain input/output prompt pairs"""
from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
class RawInputOutputStrategy(PromptTokenizingStrategy):
"""Prompt Strategy class for input/output pairs"""
def __init__(self, *args, eos_token=None, **kwargs):
super().__init__(*args, **kwargs)
self.eos_token = eos_token
if not eos_token:
self.eos_token = self.tokenizer.eos_token
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
input_ids = []
labels = []
for label, text in self.prompter.build_prompt(prompt["segments"]):
tokenized_output = self.tokenizer(
text, add_special_tokens=False, return_tensors=None
)["input_ids"]
input_ids += tokenized_output
if label or self.train_on_inputs:
labels += tokenized_output
else:
labels += [IGNORE_TOKEN_ID] * len(tokenized_output)
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
class RawInputOutputPrompter(Prompter):
"""prompter for raw i/o data"""
def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
for segment in source:
yield segment["label"], segment["text"]
def load(tokenizer, cfg):
return RawInputOutputStrategy(
RawInputOutputPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -1,15 +1,10 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.utils.tokenization import (
chatml_to_conversation,
merge_consecutive_messages,
)
def register_chatml_template(system_message=None):
@@ -24,16 +19,6 @@ def register_chatml_template(system_message=None):
sep="<|im_end|>",
)
)
register_conv_template(
Conversation(
name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -92,26 +77,12 @@ def load_guanaco(tokenizer, cfg):
)
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = False
_strict = True
@property
def strict(self):
@@ -125,25 +96,10 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
conversations = prompt["conversations"]
if self.strict:
return conversations
role_key = "from"
if "role" in conversations[0].keys():
role_key = "role"
value_key = "value"
if "text" in conversations[0].keys():
value_key = "text"
elif "content" in conversations[0].keys():
value_key = "content"
# remap roles - allow for assistant turn"
role_map = {
"user": "human",
"human": "human",
"assistant": "gpt",
"gpt": "gpt",
"system": "system",
}
# remap roles - allow for assistant turn
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [
{"from": role_map[t[role_key]], "value": t[value_key]}
for t in conversations
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
]
return turns
@@ -187,15 +143,3 @@ class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingSt
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps glaive data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversation = chatml_to_conversation(prompt)
conversation = merge_consecutive_messages(conversation)
return conversation

View File

@@ -360,19 +360,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}")
continue
tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
user, assistant = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user_role_label in role:
if user in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -392,7 +384,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant_role_label in role:
elif assistant in role:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -434,8 +426,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

View File

@@ -267,8 +267,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None
def __init__(
self,
@@ -276,7 +274,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
):
if conversation:
if isinstance(conversation, Conversation):
@@ -289,8 +286,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
def _build_result(self, source):
if len(source) < 2:
@@ -308,8 +303,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
if self.role_key_tool:
roles[self.role_key_tool] = conv.roles[2]
try:
# Apply prompt templates

View File

@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -99,7 +99,7 @@ def train(
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
freeze_parameters_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg,
@@ -208,10 +208,7 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
try:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
pass
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()

View File

@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
or not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
or torch.device(device).type == "meta"
):
return default_value
return func(*args, **kwargs)
return wrapper
@@ -47,12 +47,6 @@ def gpu_memory_usage_all(device=0):
return usage, reserved - usage, max(0, smi - reserved)
def mps_memory_usage_all():
usage = torch.mps.current_allocated_memory() / 1024.0**3
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
return usage, reserved - usage, 0
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
@@ -69,10 +63,7 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
else:
usage, cache, misc = gpu_memory_usage_all(device)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")

View File

@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List
import evaluate
import mlflow
import numpy as np
import pandas as pd
import torch
@@ -41,8 +42,8 @@ from axolotl.utils.distributed import (
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
class EvalFirstStepCallback(
@@ -61,6 +62,7 @@ class EvalFirstStepCallback(
):
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and args.eval_steps < 1.0
and state.global_step == 1
):
control.should_evaluate = True
@@ -359,187 +361,6 @@ def bench_eval_callback_factory(trainer, tokenizer):
return BenchEvalCallback
def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
class CausalLMBenchEvalCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.metrics = self.__maybe_load_metrics()
def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics
def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges
def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
return metric_score
def evaluate_preds(sources, predictions, references):
scores = {}
for metric_name, metric in self.metrics.items():
score = compute(
metric,
references=references,
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores
def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []
for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
prompt_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
return eval_src, eval_pred, eval_ref
if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
return control
return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
@@ -567,7 +388,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
max_new_tokens=self.cfg.eval_table_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
@@ -755,3 +576,31 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
"""Callback to save axolotl config to mlflow"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control

Some files were not shown because too many files have changed in this diff Show More