Compare commits

...

34 Commits

Author SHA1 Message Date
Wing Lian
772cd870d4 fix the sed command to replace the version w the tag
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-09-11 13:44:19 -04:00
Wing Lian
6c5fbe6223 add long_description for pypi push (#555) 2023-09-11 13:34:29 -04:00
Wing Lian
bcbc9597e9 replace tags, build dist for pypi publish (#553)
* replace tags, build dist for pypi publish

* missing trailing comma
2023-09-11 13:25:41 -04:00
The Objective Dad
6d57f2f0f0 ergonomic update to optimizer config doc (#548) 2023-09-11 12:35:45 -04:00
Wing Lian
20ed4c1f9e pypi on tag push (#552) 2023-09-11 10:33:42 -04:00
Wing Lian
c5dedb17ad remove with section, doesn't seem to work (#551) 2023-09-11 10:27:17 -04:00
Wing Lian
b56503d423 publish to pypi workflow on tagged release (#549) 2023-09-11 09:44:47 -04:00
Wing Lian
a94f9cb99e fix for quant config from model (#540) 2023-09-10 12:40:52 -04:00
dongxiaolong
c1921c9acb Update requirements.txt (#543)
fix fsdp
2023-09-08 16:07:11 -04:00
Wing Lian
0b4cf5bc8c workaround for md5 variations (#533)
* workaround for md5 variations

* refactor the prepared hash too
2023-09-08 16:01:05 -04:00
SlapDrone
78ee2cdab2 add git environment variables to compose: avoid checkout failure error 128 on build (#534) 2023-09-08 15:59:49 -04:00
Wing Lian
34c0a86a11 update readme to point to direct link to runpod template, cleanup install instrucitons (#532)
* update readme to point to direct link to runpod template, cleanup install instrucitons

* default install flash-attn and auto-gptq now too

* update readme w flash-attn extra

* fix version in setup
2023-09-08 11:58:54 -04:00
The Objective Dad
5e2d8a42d9 Adding NCCL Timeout Guide (#536)
* fixes NCCL_P2P_LEVEL=NVL #429

* adding more insights into verious values of NCCL_P2P_LEVEL
2023-09-08 11:57:47 -04:00
Wing Lian
e30f1e3cf7 Early stopping metric (#537)
* set early stopping metric to check

* tweak how load_best_model_at_end gets set for early stopping

* add validation for earl;y stopping patience

* remove negation

* save results to metrics in callback

* move early stopping callback after the benchmark evals

* broadcast metrics so early stopping works
2023-09-08 11:57:02 -04:00
Wing Lian
343714972b recommend padding when using sample packing (#531) 2023-09-06 17:00:21 -04:00
Wing Lian
245c5c41e2 log rank too (#527) 2023-09-06 08:37:51 -04:00
Wing Lian
a546ca2813 misc fixes/improvements (#513)
fix per pr feedback
2023-09-05 16:40:13 -04:00
Wing Lian
3355706e22 Add support for GPTQ using native transformers/peft (#468)
* auto gptq support

* more tweaks and add yml

* remove old gptq docker

* don't need explicit peft install for tests

* fix setup.py to use extra index url

install torch for tests
fix cuda version for autogptq index
set torch in requirements so that it installs properly
move gptq install around to work with github cicd

* gptq doesn't play well with sample packing

* address pr feedback

* remove torch install for now

* set quantization_config from model config

* Fix the implementation for getting quant config from model config
2023-09-05 12:43:22 -04:00
mhenrichsen
daa4faca12 Merge pull request #520 from bdashore3/sharegpt-fixes
Allow for custom system prompts with ShareGPT
2023-09-05 09:02:55 +02:00
Aman Karmani
fc8766e502 reorg a bit 2023-09-05 02:21:24 +00:00
Aman Gupta Karmani
72a6fe1c1f use flash_attn rmsnorm when available (#526)
* use flash_attn xentropy when available

* use flash_attn.ops.rms_norm when available

* log when xentropy is not found

* log how to install RMSNorm

* add quotes so pip install works
2023-09-04 19:44:51 -04:00
Aman Gupta Karmani
5fe30b1497 use flash_attn xentropy when available (#525)
* use flash_attn xentropy when available

* log when xentropy is not found
2023-09-04 17:49:16 -04:00
Aman Gupta Karmani
44454ae4c4 move is_llama_derived_model into normalize_config (#524) 2023-09-04 00:19:03 -04:00
Wing Lian
09f154397e No gather single gpu (#523)
* don't attempt to gather on multi-gpu

* also check distributed status in bench callback
2023-09-03 23:24:28 -04:00
kingbri
995557bdf3 Prompters: ShareGPT: Allow for custom system prompts
If a system prompt is present in a conversation, add it instead of
using the default.

Signed-off-by: kingbri <bdashore3@proton.me>
2023-09-01 13:53:05 -04:00
Maxime
1991946c5a fix: bad dtype for full finetune (#504)
* fix: bad dtype for full finetune

* Update src/axolotl/utils/models.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update models.py

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-09-01 07:11:45 -07:00
NanoCode012
f51c9c56c6 Fix(doc): Inform Windows users to use WSL/docker (#518) 2023-09-01 00:08:21 -07:00
Wing Lian
7710e81f50 log supervised token count (#448) 2023-08-31 15:45:23 -07:00
Tom Jobbins
48434bec54 Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511) 2023-08-31 14:26:52 -07:00
Jan Philipp Harries
396a7a74fc Added advanced DDP args (#515)
* add ddp_config

* add advanced ddp config

* add ddp_config

* add advanced ddp config

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-08-31 10:37:47 -07:00
Wing Lian
b21e4a20fe split train from other cli options (#503) 2023-08-30 22:01:47 -07:00
Alpay Ariyak
42f9642792 Changed Bench Eval to report metrics correctly by split. Added total accuracy and renamed previously used bench_accuracy to bench_average_accuracy. (#512)
* Added "eval_" prefix

* Added total bench accuracy and renamed the previous one to bench_average_accuracy. Changed naming to use bench_split instead of always using eval_ prefix.
2023-08-30 22:00:50 -07:00
Wing Lian
c56b450cf5 drop empty tokenized rows too (#509) 2023-08-30 06:55:26 -07:00
Aman Gupta Karmani
1e07c162f1 set zero3 optimizer betas to auto so they inherit from HF trainer config (#507) 2023-08-30 08:10:33 -04:00
38 changed files with 823 additions and 453 deletions

View File

@@ -23,11 +23,6 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras: gptq
runs-on: self-hosted
steps:
- name: Checkout
@@ -73,11 +68,6 @@ jobs:
pytorch: 2.0.1
axolotl_extras:
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.1
axolotl_extras: gptq
runs-on: self-hosted
steps:
- name: Checkout

45
.github/workflows/pypi.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: publish pypi
on:
push:
tags:
- '*'
jobs:
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip3 install wheel
pip3 install -e .
pip3 install -r requirements-tests.txt
- name: Extract tag name
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in setup.py
run: >-
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a binary wheel
run: >-
python setup.py sdist bdist_wheel
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -24,8 +24,8 @@ jobs:
- name: Install dependencies
run: |
pip install -e .[peft]
pip install -r requirements-tests.txt
pip3 install -e .
pip3 install -r requirements-tests.txt
- name: Run tests
run: |

View File

@@ -90,8 +90,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
```
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
Or run on the current files for development:
@@ -104,19 +103,9 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install python dependencies with ONE of the following:
- Recommended, supports QLoRA, NO gptq/int4 support
3. Install axolotl along with python dependencies
```bash
pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
```
- gptq/int4 support, NO QLoRA
```bash
pip3 install -e .[gptq]
```
- same as above but not recommended
```bash
pip3 install -e .[gptq_triton]
pip3 install -e .[flash-attn]
```
- LambdaLabs
@@ -151,10 +140,9 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install -e . # change depend on needs
pip3 install -e .
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
```
5. Set path
@@ -163,6 +151,8 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```
</details>
- Windows: Please use WSL or Docker!
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
@@ -570,6 +560,30 @@ log_sweep_min_lr:
log_sweep_max_lr:
# specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_hf
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adafactor
# - adamw_anyprecision
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_lion_32bit
# - paged_lion_8bit
optimizer:
# specify weight decay
weight_decay:
@@ -623,6 +637,11 @@ fsdp_config:
# Deepspeed config path
deepspeed:
# Advanced DDP Arguments
ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
@@ -745,6 +764,10 @@ Try to turn off xformers.
It's safe to ignore it.
> NCCL Timeouts during training
See the [NCCL](docs/nccl.md) guide.
## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you

View File

@@ -35,10 +35,7 @@
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.95
],
"betas": "auto",
"eps": 1e-8,
"weight_decay": "auto"
}

View File

@@ -9,6 +9,11 @@ services:
- ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables
environment:
# Set environment variables
- GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}
- GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}
- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
- WANDB_API_KEY=${WANDB_API_KEY}
deploy:
resources:

View File

@@ -11,7 +11,6 @@ RUN apt-get update && \
WORKDIR /workspace
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \

46
docs/nccl.md Normal file
View File

@@ -0,0 +1,46 @@
# NCCL
NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort:
```text
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
```
Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
```shell
nvidia-smi nvlink --status
```
To force NCCL to use NVLink, simply set this in the environment:
```shell
export NCCL_P2P_LEVEL=NVL
```
If NVLink is not available in your environment there are other options for ``NCCL_P2P_LEVEL`` in the table below:
| NCCL_P2P_LEVEL | Description |
| -------------- | ----------- |
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates vs to paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. |
| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
| PHB | P2P data transfers occur over the PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (ex PIX, NVL) |
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
```shell
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
```shell
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
```
Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.

View File

@@ -17,6 +17,7 @@ output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:

View File

@@ -20,6 +20,7 @@ lora_model_dir:
sequence_len: 100000
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16

View File

@@ -17,6 +17,7 @@ output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:

View File

@@ -20,6 +20,7 @@ lora_model_dir:
sequence_len: 100000
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16

View File

@@ -17,6 +17,7 @@ output_dir: ./lora-out
sequence_len: 100000
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:

View File

@@ -20,6 +20,7 @@ lora_model_dir:
sequence_len: 100000
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16

View File

@@ -1,8 +0,0 @@
# LLaMa 7B using LoRA
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
```shell
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
```

View File

@@ -1,63 +0,0 @@
base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code:
load_in_8bit: true
gptq: true
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./llama-7b-lora-int4
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
train_on_inputs: false
group_by_length: false
fp16: true
bf16: false
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 5
xformers_attention:
flash_attention:
gradient_checkpointing: true
gptq_groupsize: 128
gptq_model_v1: false
warmup_steps: 20
eval_steps: 110
save_steps: 660
debug:
deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
tokens:
pad_token: "<pad>"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -0,0 +1,76 @@
base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_bits: 4
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- k_proj
- o_proj
- q_proj
- v_proj
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000017
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
sdp_attention:
flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100
eval_steps:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -17,6 +17,7 @@ output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:

View File

@@ -20,6 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16

View File

@@ -20,6 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 8
lora_alpha: 16

View File

@@ -1,14 +1,18 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
accelerate @ git+https://github.com/huggingface/accelerate
addict
evaluate
fire
PyYAML>=6.0
datasets
flash-attn>=2.0.8
flash-attn>=2.2.1
sentencepiece
wandb
einops

View File

@@ -4,9 +4,7 @@ import importlib
import logging
import os
import random
import signal
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
@@ -17,17 +15,17 @@ import yaml
# add src to the pythonpath so we don't need to pip install this
from art import text2art
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model, load_model_config, load_tokenizer
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -40,26 +38,13 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
if is_main_process():
print(ascii_art)
@@ -73,9 +58,45 @@ def get_multi_line_input() -> Optional[str]:
return instruction
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
if prompter == "None":
prompter = None
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
@@ -176,141 +197,6 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
if not (
cli_args.shard or cli_args.merge_lora or cli_args.inference
): # don't need to load dataset for these
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
)
if cli_args.prepare_ds_only:
LOG.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
if cli_args.merge_lora and cfg.adapter is not None:
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cli_args.inference:
LOG.debug("Running inference on model")
do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
return
if cli_args.shard:
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
@@ -330,15 +216,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
validate_config(cfg)
normalize_config(cfg)
@@ -347,15 +224,55 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
return cfg
def do_train(config: Path = Path("examples/"), **kwargs):
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
train(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_train)
fire.Fire(do_cli)

View File

@@ -2,38 +2,41 @@
from setuptools import find_packages, setup
install_requires = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
reqs = [r for r in reqs if "flash-attn" not in r]
reqs = [r for r in reqs if r and r[0] != "#"]
for r in reqs:
install_requires.append(r)
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif "flash-attn" not in line and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
return _install_requires, _dependency_links
install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.1",
description="You know you're going to axolotl questions",
version="0.3.0",
description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages(),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
"gptq": [
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"flash-attn": [
"flash-attn==2.0.8",
"flash-attn>=2.2.1",
],
"extras": [
"deepspeed",
],
"peft": [
"peft @ git+https://github.com/huggingface/peft.git",
],
},
)

View File

43
src/axolotl/common/cli.py Normal file
View File

@@ -0,0 +1,43 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
return model, tokenizer

View File

@@ -23,6 +23,7 @@ class ColorfulFormatter(Formatter):
}
def format(self, record):
record.rank = int(os.getenv("LOCAL_RANK", "0"))
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
@@ -35,7 +36,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
},
},
"filters": {},

View File

@@ -2,7 +2,9 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import logging
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
@@ -33,6 +35,9 @@ except ImportError:
)
LOG = logging.getLogger("axolotl")
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
@@ -44,6 +49,34 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
llama_model_forward
)
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask

View File

@@ -309,10 +309,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
)
def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
if source[0]["from"] == "system":
source.pop(0)
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
@@ -321,6 +317,12 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
)
conv = self._conversation.copy()
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.system = source[0]["value"]
source.pop(0)
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
try:

141
src/axolotl/train.py Normal file
View File

@@ -0,0 +1,141 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
# add src to the pythonpath so we don't need to pip install this
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.train")
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...")
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
return model, tokenizer

View File

@@ -25,8 +25,10 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
broadcast_dict,
gather_scalar_from_all_ranks,
get_world_size,
is_distributed,
is_main_process,
zero_first,
)
@@ -270,12 +272,16 @@ def bench_eval_callback_factory(trainer, tokenizer):
lambda: len(data_loader), get_world_size()
)
if not is_main_process():
results = {}
if is_distributed() and not is_main_process():
dist.gather_object(local_bench_names, dst=0)
else:
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
if is_distributed():
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
else:
gathered_bench_names = [local_bench_names]
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
results = {"bench_loss": bench_loss}
results = {f"{bench_split}_bench_loss": bench_loss}
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
@@ -287,6 +293,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
bench_refs = []
bench_preds = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
@@ -294,15 +302,24 @@ def bench_eval_callback_factory(trainer, tokenizer):
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score):
results[
f"bench_{bench_split}_accuracy_{bench_name}"
f"{bench_split}_bench_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
references=bench_refs, predictions=bench_preds
)["accuracy"]
trainer.log(results)
results = broadcast_dict(results)
for key, val in results.items():
metrics[key] = val
return BenchEvalCallback

View File

@@ -6,6 +6,7 @@ import os
import torch
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.models import load_model_config
LOG = logging.getLogger("axolotl")
@@ -69,6 +70,16 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32
model_config = load_model_config(cfg)
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -86,6 +97,11 @@ def validate_config(cfg):
)
)
if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
@@ -97,9 +113,7 @@ def validate_config(cfg):
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.load_4bit:
raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
)
raise ValueError("cfg.load_4bit parameter has been deprecated")
if cfg.adapter == "qlora":
if cfg.merge_lora:
@@ -206,6 +220,15 @@ def validate_config(cfg):
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if cfg.save_steps % cfg.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -2,7 +2,6 @@
import functools
import hashlib
import logging
from hashlib import md5
from pathlib import Path
from typing import Tuple, Union
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
md5( # nosec
md5(
(
str(cfg.sequence_len)
+ "@"
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
)
+ "|"
+ tokenizer_name
).encode("utf-8")
).hexdigest()
)
)
)
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
@@ -374,7 +380,7 @@ def load_prepare_datasets(
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str(
md5( # nosec
md5(
(
str(cfg.sequence_len)
+ "@"
@@ -385,8 +391,8 @@ def load_prepare_datasets(
)
+ "|"
+ tokenizer_name
).encode("utf-8")
).hexdigest()
)
)
)
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
@@ -500,12 +506,8 @@ def load_prepare_datasets(
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = hashlib.md5(
to_hash_train.encode(), usedforsecurity=False
).hexdigest()
test_fingerprint = hashlib.md5(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
with zero_first(is_main_process()):
dataset = dataset.train_test_split(

View File

@@ -2,6 +2,7 @@
utility helpers for distributed checks
"""
import os
import pickle # nosec
from contextlib import contextmanager
import torch
@@ -74,6 +75,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not is_main_process():
@@ -91,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
gathered_values.append(float(tensor.item()))
return gathered_values
return None
def broadcast_dict(vals: dict):
if not is_distributed():
return vals
if is_main_process():
data_byte = pickle.dumps(vals)
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
else:
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
data_size = torch.IntTensor([0]).to("cuda")
dist.broadcast(data_size, 0)
if not is_main_process():
# resize
data_tensor = data_tensor.new_empty([data_size.item()])
dist.broadcast(data_tensor, 0)
if not is_main_process():
data_list = data_tensor.cpu().tolist()
data_byte = bytes(data_list[: data_size.item()])
vals = pickle.loads(data_byte) # nosec
return vals

View File

@@ -4,19 +4,19 @@
import logging
import math
import os
from pathlib import Path
from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig
from peft import PeftConfig, prepare_model_for_kbit_training
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
@@ -155,32 +155,17 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
try:
if cfg.gptq:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
replace_peft_model_with_int4_lora_model,
)
replace_peft_model_with_int4_lora_model()
except Exception as err:
LOG.exception(err)
raise err
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
try:
from peft import prepare_model_for_kbit_training
except ImportError:
# For backward compatibility
from peft import (
prepare_model_for_int8_training as prepare_model_for_kbit_training,
)
model_kwargs = {}
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
model_config = load_model_config(cfg)
if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")
else:
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
@@ -191,45 +176,7 @@ def load_model(
bnb_4bit_quant_type="nf4",
)
try:
if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
try:
snapshot_download_kwargs = {}
if cfg.base_model_ignore_patterns:
snapshot_download_kwargs[
"ignore_patterns"
] = cfg.base_model_ignore_patterns
cache_model_path = Path(
snapshot_download(base_model, **snapshot_download_kwargs)
)
files = (
list(cache_model_path.glob("*.pt"))
+ list(cache_model_path.glob("*.safetensors"))
+ list(cache_model_path.glob("*.bin"))
)
if len(files) > 0:
model_path = str(files[0])
else:
LOG.warning(
"unable to find a cached model file, this will likely fail..."
)
model_path = str(cache_model_path)
except Exception: # pylint: disable=broad-exception-caught
model_path = cfg.base_model
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
model_path,
device_map=cfg.device_map,
half=cfg.fp16,
groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
is_v1_model=cfg.gptq_model_v1
if cfg.gptq_model_v1 is not None
else True,
)
load_in_8bit = False
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
config_kwargs = {}
@@ -275,15 +222,24 @@ def load_model(
# )
# model.train() # sets to train instead of eval mode
elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
@@ -359,11 +315,12 @@ def load_model(
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
@@ -371,7 +328,7 @@ def load_model(
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
@@ -385,22 +342,10 @@ def load_model(
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
if cfg.gptq:
# Scales to half
LOG.info("Fitting 4bit scales and zeros to half")
for _, module in model.named_modules():
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
type(module)
):
if hasattr(module, "is_v1_model") and module.is_v1_model:
module.zeros = module.zeros.half()
module.scales = module.scales.half()
module.bias = module.bias.half()
if (
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
and (cfg.gptq or cfg.load_in_4bit)
and (cfg.load_in_4bit)
):
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see

View File

@@ -8,13 +8,13 @@ from termcolor import colored
LOG = logging.getLogger("axolotl")
def check_dataset_labels(dataset, tokenizer):
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
# the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(5):
check_example_labels(dataset[idx], tokenizer)
for idx in range(num_examples):
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer):
def check_example_labels(example, tokenizer, text_only=False):
# Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"]
labels = example["labels"]
@@ -29,8 +29,10 @@ def check_example_labels(example, tokenizer):
decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
colored_token = colored(decoded_input_token, color) + colored(
f"({label_id}, {mask}, {input_id})", "white"
colored_token = colored(decoded_input_token, color) + (
not text_only
and colored(f"({label_id}, {mask}, {input_id})", "white")
or ""
)
colored_tokens.append(colored_token)

View File

@@ -33,6 +33,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
LOG = logging.getLogger("axolotl")
@@ -361,7 +362,7 @@ def add_position_ids(sample):
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
@contextmanager
@@ -375,14 +376,17 @@ def disable_datasets_caching():
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
with zero_first(is_main_process()):
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
if eval_dataset:
eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids, num_proc=os.cpu_count()
)
return train_dataset, eval_dataset
@@ -401,6 +405,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est:
total_num_steps = (
# match count to len est in dataloader
@@ -504,23 +518,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing:
if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
apply_gradient_checkpointing,
)
gradient_checkpointing_ratio = (
cfg.gradient_checkpointing_ratio
if cfg.gradient_checkpointing_ratio
else 1.0
)
apply_gradient_checkpointing(
model, checkpoint_ratio=gradient_checkpointing_ratio
)
else:
training_arguments_kwargs[
"gradient_checkpointing"
] = cfg.gradient_checkpointing
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
if cfg.fsdp:
training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config:
@@ -578,6 +576,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
if cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
if cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
@@ -594,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=(
cfg.load_best_model_at_end is not False
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.save_steps % cfg.eval_steps == 0
and cfg.load_in_8bit is not True
)
or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
@@ -630,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter in [
"lora",
"qlora",
@@ -703,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
trainer.add_callback(early_stop_cb)
return trainer

64
tests/test_data.py Normal file
View File

@@ -0,0 +1,64 @@
"""
test module for the axolotl.utis.data module
"""
import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
class TestEncodePretraining(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""
def setUp(self):
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"eos_token": "</s>",
"bos_token": "<s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)
self.max_tokens = 15 # set a small number for easy inspection
def test_encode_pretraining(self):
examples = {
"text": [
"Hello, world!",
"Nice to meet you.",
"lorem ipsum dolor sit amet.",
"Nice to meet you again!.",
"hello, hello",
]
}
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
self.assertEqual(len(result["input_ids"]), 3)
# Assert the length of input_ids and attention_mask is correct
self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
# Assert EOS and PAD tokens are correctly added
# hello world! is 4 tokens
self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
# second part, 5 tokens
self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
def test_md5(self):
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
self.assertEqual(
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)
if __name__ == "__main__":
unittest.main()

View File

@@ -328,6 +328,20 @@ class ValidationTest(unittest.TestCase):
for record in self._caplog.records
)
cfg = DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,