Compare commits


66 Commits

Author SHA1 Message Date
Wing Lian
f6721baf10 tweak to make it work when we have no explicit test split
2023-07-11 22:40:21 -04:00
Wing Lian
33814cc94e make sure we eval for openorca 2023-07-02 17:59:10 -04:00
Wing Lian
50254a7ccc handle orca splits 2023-07-01 07:20:23 -04:00
Wing Lian
3a783c04e4 Merge pull request #247 from OpenAccess-AI-Collective/fix-apex-base
update pip install command for apex
2023-07-01 06:18:25 -04:00
Wing Lian
1e5014acec Merge pull request #255 from OpenAccess-AI-Collective/open-orca-prompts
open orca support
2023-07-01 01:11:23 -04:00
Wing Lian
a10da1caff 11.7.0 nvidia/cuda docker images are deprecated, move to 11.7.1
2023-07-01 00:29:07 -04:00
Wing Lian
4066c78631 Merge pull request #246 from OpenAccess-AI-Collective/sys-prompts-instruct
add option for instruct w sys prompts
2023-07-01 00:27:29 -04:00
Wing Lian
78a1e1fa12 open orca support 2023-07-01 00:19:41 -04:00
NanoCode012
bc8a2e5547 Merge pull request #249 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix typing list in prompt tokenizer
2023-06-30 15:01:41 +09:00
NanoCode012
910ebe47f5 Merge pull request #252 from OpenAccess-AI-Collective/NanoCode012-readme-fix
Add cfg.push_to_hub_model_id to readme
2023-06-30 14:56:55 +09:00
NanoCode012
c146880a75 Update README.md 2023-06-30 11:33:53 +09:00
NanoCode012
77bdb7d144 Fix typing list 2023-06-29 14:29:55 +09:00
Wing Lian
530809fd74 update pip install command for apex 2023-06-28 22:36:28 -04:00
Wing Lian
924bbfddec add option for instruct w sys prompts 2023-06-28 22:27:17 -04:00
Wing Lian
f150c027e3 Merge pull request #224 from OpenAccess-AI-Collective/system-prompt-data
System prompt data
2023-06-27 17:57:43 -04:00
Wing Lian
5c39c006c9 Merge pull request #244 from OpenAccess-AI-Collective/push-to-hub
push intermediate model checkpoints to hub
2023-06-27 17:57:30 -04:00
Wing Lian
612aabd8c4 push intermediate model checkpoints to hub 2023-06-27 15:40:25 -04:00
Wing Lian
af05883f75 Merge pull request #243 from OpenAccess-AI-Collective/unprompted-instruct
skip the system prompt
2023-06-25 22:50:35 -04:00
Wing Lian
05ab9092e3 skip the system prompt 2023-06-25 22:40:50 -04:00
Wing Lian
7b57ed7618 pylint for duplicated code for system prompts 2023-06-25 22:28:07 -04:00
Wing Lian
3a38271276 add tests and supoort for loader for sys prompt data 2023-06-25 22:28:07 -04:00
Wing Lian
8d20e0a3d3 initial wip to get sys prompt from dataset 2023-06-25 22:28:07 -04:00
Wing Lian
de8ed229c3 Merge pull request #240 from OpenAccess-AI-Collective/tokenizer-fast
optionally define whether to use_fast tokenizer
2023-06-25 12:47:55 -04:00
Wing Lian
478d8c7b8e Merge pull request #241 from OpenAccess-AI-Collective/py3-pre-commit
better py3 support w pre-commit
2023-06-25 12:47:02 -04:00
Wing Lian
645c13592c better py3 support w pre-commit 2023-06-25 10:26:02 -04:00
Wing Lian
47d601fa23 optionally define whether to use_fast tokenizer 2023-06-25 10:19:49 -04:00
Wing Lian
756dfba97b Merge pull request #218 from OpenAccess-AI-Collective/no-fail-fast
don't fail fast
2023-06-23 15:42:54 -04:00
Wing Lian
91ab0592af Merge pull request #235 from msinha251/Fixing-data-readme 2023-06-23 13:52:01 -04:00
Mahesh Sinha
0aeb7c7802 Fixing Data Readme 2023-06-21 15:34:48 +02:00
Wing Lian
d35278aaf1 don't fail fast 2023-06-15 16:01:27 -04:00
Wing Lian
9492d4ebb7 Merge pull request #215 from OpenAccess-AI-Collective/adamw-hyperparams-cfg
support adamw and grad norm hyperparams
2023-06-15 12:20:55 -04:00
Wing Lian
ad5ca4f734 Additional test case per pr 2023-06-15 10:12:47 -04:00
Wing Lian
cb9d3af5c0 add validation and tests for adamw hyperparam 2023-06-15 09:39:42 -04:00
Wing Lian
c969f0a9dc add docs 2023-06-15 08:43:20 -04:00
Wing Lian
6d0ee4ba34 support adamw and grad norm hyperparams 2023-06-15 08:40:41 -04:00
Wing Lian
a81f52d575 Merge pull request #212 from OpenAccess-AI-Collective/doc-20230615-v1
add float16 docs and tweak typehints
2023-06-15 08:28:57 -04:00
Wing Lian
1925eaf1e6 Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Fix tokenizing labels
2023-06-15 08:13:43 -04:00
Wing Lian
1ab3bf3e67 fix test name 2023-06-15 02:09:33 -04:00
Wing Lian
d7635b7148 hint to what AMP means 2023-06-15 02:06:27 -04:00
Wing Lian
88e17ffc50 add float16 docs and tweak typehints 2023-06-15 02:05:31 -04:00
Wing Lian
baed440fa1 ingore duplicate code in tests 2023-06-15 02:03:53 -04:00
Wing Lian
7925ddce86 bugfix for potential off by one 2023-06-15 01:59:33 -04:00
Wing Lian
6f849809c5 Merge pull request #206 from MaciejKarasek/issue205
issue #205 bugfix
2023-06-14 14:23:38 -04:00
Wing Lian
c16644d05e Merge pull request #209 from sroecker/fix_redpajama_example_tokenizer
Use AutoTokenizer for redpajama example
2023-06-14 14:23:21 -04:00
Steffen Röcker
945c4191a3 Use AutoTokenizer for redpajama example 2023-06-14 20:09:26 +02:00
maciej.karasek
136522f9c9 style correction 2023-06-14 20:02:09 +02:00
maciej.karasek
556fe408b3 issue #205 bugfix 2023-06-14 16:59:57 +02:00
Wing Lian
16bb6276a5 Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum
add support for opimum bettertransformers
2023-06-14 07:50:15 -04:00
NanoCode012
06674a11f2 Merge pull request #202 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix sharegpt type in doc
2023-06-14 09:48:35 +09:00
NanoCode012
3513885f43 Fix sharegpt type 2023-06-14 01:10:58 +09:00
Wing Lian
4b43a66a0b update alpaca_chat prompts for instructions to explainn the conversation 2023-06-12 18:38:38 -04:00
Wing Lian
fd2c9814c9 Merge branch 'main' into flash-optimum 2023-06-12 13:12:15 -04:00
Wing Lian
c9a149f9e8 add check for attr 2023-06-11 10:11:17 -04:00
Wing Lian
958da70376 fix formatting 2023-06-10 15:28:08 -04:00
Wing Lian
759e8673ce Update scripts/finetune.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-10 14:25:21 -04:00
Wing Lian
0c6f928601 address PR feedback 2023-06-10 14:23:56 -04:00
Wing Lian
eea2731a5e add streaming dataset support for pretraining datasets 2023-06-10 14:23:56 -04:00
Wing Lian
1db46a9c72 linting fix 2023-06-10 14:23:56 -04:00
Wing Lian
ab5cd28acf more gpt-neox long ctx fixes 2023-06-10 14:23:55 -04:00
Wing Lian
1a82082e91 fix bettertransformers save, force it to skip after saving correctly in callback 2023-06-10 14:23:55 -04:00
Wing Lian
1210dc8fd5 more tweaks to do pre-training with bettertransformers 2023-06-10 14:23:55 -04:00
Wing Lian
488a67d75a experimental expansion of ctx len 2023-06-10 14:23:53 -04:00
Wing Lian
71a43f8479 add validation/warning for bettertransformers and torch version 2023-06-10 14:22:31 -04:00
Wing Lian
39619028a3 use pythia-12b, neox-20b is flaky 2023-06-10 14:22:30 -04:00
Wing Lian
8792199799 add flash attn context for efficient training and attempt setting model to train mode: 2023-06-10 14:22:30 -04:00
Wing Lian
1edc30c786 add support for opimum bettertransformers 2023-06-10 14:22:30 -04:00
28 changed files with 942 additions and 120 deletions

View File

@@ -12,6 +12,7 @@ jobs:
     # this job needs to be run on self-hosted GPU runners...
     runs-on: self-hosted
     strategy:
+      fail-fast: false
       matrix:
         include:
           - cuda: "118"
@@ -25,7 +26,7 @@ jobs:
             pytorch: 2.0.0
             axolotl_extras:
           - cuda: "117"
-            cuda_version: 11.7.0
+            cuda_version: 11.7.1
             python_version: "3.9"
             pytorch: 1.13.1
             axolotl_extras:

View File

@@ -11,6 +11,7 @@ jobs:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
     strategy:
+      fail-fast: false
       matrix:
         include:
           - cuda: cu118
@@ -29,7 +30,7 @@ jobs:
             pytorch: 2.0.0
             axolotl_extras: gptq
           - cuda: cu117
-            cuda_version: 11.7.0
+            cuda_version: 11.7.1
             python_version: "3.9"
             pytorch: 1.13.1
             axolotl_extras:
@@ -84,7 +85,7 @@ jobs:
             pytorch: 2.0.0
             axolotl_extras: gptq
           - cuda: cu117
-            cuda_version: 11.7.0
+            cuda_version: 11.7.1
             python_version: "3.9"
             pytorch: 1.13.1
             axolotl_extras:

View File

@@ -7,6 +7,7 @@ jobs:
   test:
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
       matrix:
         python_version: ["3.9", "3.10"]
     timeout-minutes: 10

View File

@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.9
+  python: python3
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks

View File

@@ -138,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"instruction": "...", "input": "...", "output": "..."}
   ```
-- `sharegpt`: conversations
+- `sharegpt:chat`: conversations
   ```json
   {"conversations": [{"from": "...", "value": "..."}]}
   ```
@@ -195,6 +195,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"message_1": "...", "message_2": "..."}
   ```
+- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
+  ```json
+  {"system_prompt": "...", "question": "...", "response": "..."}
+  ```
 - `context_qa`: in context question answering from an article
   ```json
   {"article": "...", "question": "...", "answer": "..."}
   ```
@@ -264,6 +268,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
 bf16: true # require >=ampere
 fp16: true
 tf32: true # require >=ampere
+bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
+float16: true # use instead of fp16 when you don't want AMP
 ```

 Note: Repo does not do 4-bit quantization.
@@ -300,6 +306,8 @@ model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
 # Trust remote code for untrusted source
 trust_remote_code:
+# use_fast option for tokenizer loading from_pretrained, default to True
+tokenizer_use_fast:
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
@@ -332,6 +340,8 @@ datasets:
 dataset_prepared_path: data/last_run_prepared
 # push prepared dataset to hub
 push_dataset_to_hub: # repo path
+# push checkpoints to hub
+push_to_hub_model_id: # repo path
 # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
 # required to be true when used in combination with `push_dataset_to_hub`
 hf_use_auth_token: # boolean
@@ -420,7 +430,15 @@ log_sweep_max_lr:
 optimizer:
 # specify weight decay
 weight_decay:
+# adamw hyperparams
+adam_beta1:
+adam_beta2:
+adam_epsilon:
+# Gradient clipping max norm
+max_grad_norm:
+# whether to bettertransformers
+flash_optimum:
 # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
@@ -520,6 +538,12 @@ Add below flag to train command above
 --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```

+If you run out of CUDA memory, you can try to merge in system RAM with
+```bash
+CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
+```
+
 ## Common Errors 🧰

 > Cuda out of memory

View File

@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
 ## Convert the JSON data files to JSONL.

 ```shell
-python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
-python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
-python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
-python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
 ```
 ---

View File

@@ -77,7 +77,7 @@ FROM base-builder
 RUN python3 -m pip uninstall -y apex
 RUN git clone https://github.com/NVIDIA/apex
 # `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
-RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
+RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
 RUN mkdir -p /workspace/builds
 COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes

View File

@@ -0,0 +1,9 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️

View File

@@ -0,0 +1,49 @@
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -1,7 +1,7 @@
 base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
 base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
 model_type: GPTNeoXForCausalLM
-tokenizer_type: GPTNeoXTokenizer
+tokenizer_type: AutoTokenizer
 trust_remote_code:
 load_in_8bit: false
 datasets:

View File

@@ -11,6 +11,7 @@ sentencepiece
 wandb
 einops
 xformers
+optimum
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0

View File

@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig, TextStreamer
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
 # add src to the pythonpath so we don't need to pip install this
+from optimum.bettertransformer import BetterTransformer
+from transformers import GenerationConfig, TextStreamer
+
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -217,9 +218,20 @@ def train(
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            train_dataset = load_pretraining_dataset(
+                cfg.pretraining_dataset,
+                tokenizer,
+                max_tokens=cfg.sequence_len,
+                seed=cfg.seed,
+            )
+            # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
+            train_dataset = train_dataset.with_format("torch")
+            eval_dataset = None

     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -285,12 +297,15 @@ def train(
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
+
+        def terminate_handler(_, __, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir)
+            sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signal, frame: (
-                model.save_pretrained(cfg.output_dir),
-                sys.exit(0),
-            ),
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )

     logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
-    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
+        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)

     # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
                     buffer_len = 0

         if example:
+            # FIXME
             # just going to drop data points that are too long
             if len(example["input_ids"]) <= self.seq_length:
                 input_ids = example["input_ids"]

View File

@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter


 def load(tokenizer, cfg):
@@ -20,11 +20,38 @@ def load(tokenizer, cfg):

 class AlpacaConcisePrompter(AlpacaPrompter):
     """
-    Alpaca Prompter extending the system prompt to ask for concise answers
+    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
     """

-    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
-    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+
+class AlpacaChatPrompter(AlpacaPrompter):
+    """
+    Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
+    """
+
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        self.prompt_style = PromptStyle.CHAT.value
+        self.match_prompt_style()
+
+
+class NoSystemPrompter(AlpacaPrompter):
+    """
+    Null Prompter with no system prompts
+    """
+
+    system_prompt = ""
+    system_no_input_prompt = ""
+    turn_format = "{instruction} {input} "
+    turn_no_input_format = "{instruction} "
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        pass


 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +91,7 @@ def load_concise(tokenizer, cfg):

 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -73,7 +100,16 @@ def load_qa(tokenizer, cfg):

 def load_camel_ai(tokenizer, cfg):
     return CamelAIPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_no_prompt(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        UnpromptedPrompter(PromptStyle.CHAT.value),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,

View File

@@ -1,7 +1,7 @@
"""Module loading the AlpacaInstructPromptTokenizingStrategy class""" """Module loading the AlpacaInstructPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg): def load(tokenizer, cfg):
@@ -11,3 +11,12 @@ def load(tokenizer, cfg):
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
def load_no_prompt(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
UnpromptedPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
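For orientation, dataset `type` strings in an axolotl config resolve to these loader functions by `module.function` name, following the `alpaca_w_system.load_open_orca` pattern documented in the README hunk above. A minimal, hedged sketch (the dataset path is a hypothetical placeholder):

```yaml
datasets:
  # hypothetical dataset path; the type resolves to the load_no_prompt loader added above
  - path: my-org/my-instruct-data
    type: alpaca_instruct.load_no_prompt
```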

View File

@@ -0,0 +1,120 @@
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
prompt["system"],
)
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
(
instruction,
input, # pylint: disable=redefined-builtin
response,
system,
) = self.parse_instruction_fields(prompt)
user_prompt = next(
iter(
self.prompter.build_prompt_w_system(
system,
instruction,
input,
)
)
)
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_prompt
class SystemDataPrompter(AlpacaPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset
"""
def build_prompt_w_system(
self,
system: str,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = system + self.turn_format.format(instruction=instruction, input=input)
else:
res = system + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
"""
Tokenizing strategy for OpenOrca datasets
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
return (
prompt["question"],
"",
prompt["response"],
prompt["system_prompt"],
)
def load(tokenizer, cfg):
return load_chat(tokenizer, cfg)
def load_instruct(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_chat(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
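Tying this back to the README entry above, a hedged sketch of the config stanza that would exercise `load_open_orca`, assuming a dataset that exposes the documented `system_prompt`/`question`/`response` columns:

```yaml
datasets:
  # assumption: the OpenOrca dataset uses the system_prompt/question/response columns shown in the README
  - path: Open-Orca/OpenOrca
    type: alpaca_w_system.load_open_orca
```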

View File

@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """

-    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+    def parse_instruction_fields(
+        self, prompt
+    ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
         raise NotImplementedError

     def tokenize_prompt(self, prompt):
@@ -96,25 +98,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             input,  # pylint: disable=redefined-builtin
             response,
         ) = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction, input, response)
-        tokenized_full_prompt = self._tokenize(full_prompt)
-        if not self.train_on_inputs:
-            user_prompt = next(
-                iter(
-                    self.prompter.build_prompt(
-                        instruction,
-                        input,
-                    )
-                )
-            )
-            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
-            user_prompt_len = len(tokenized_user_prompt["input_ids"])
-            # TODO this could be sped up using numpy array slicing
-            tokenized_full_prompt["labels"] = [
-                -100
-            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
-
-        return tokenized_full_prompt
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                )
+            )
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
+            # TODO this could be sped up using numpy array slicing
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
+
+        return tokenized_prompt

     def _build_full_prompt(
         self, instruction, input, response  # pylint: disable=redefined-builtin
@@ -436,7 +440,7 @@ def parse_tokenized_to_result(
     result: Dict[str, List[int]],
     current_len: int,
     res: Dict[str, List[int]],
-    labels: list[int],
+    labels: List[int],
     pad_token_id: Union[int, None] = None,
 ) -> Tuple[Dict[str, List[int]], int]:
     """

View File

@@ -24,6 +24,8 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
def match_prompt_style(self): def match_prompt_style(self):
if self.prompt_style == PromptStyle.INSTRUCT.value: if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = ( self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.system_prompt self.turn_no_input_format = (
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" "### Instruction:\n{instruction}\n\n### Response:\n"
) )
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = ( self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:"
def build_prompt( def build_prompt(
self, self,
@@ -59,16 +51,17 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input: if input:
res = self.prompt_input.format(instruction=instruction, input=input) res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else: else:
res = self.prompt_no_input.format(instruction=instruction) res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output: if output:
res = f"{res}{output}" res = f"{res}{output}"
yield res yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class UnpromptedPrompter(AlpacaPrompter): class UnpromptedPrompter(AlpacaPrompter):
""" """
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
""" """
system_prompt = ( system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning." "Choose the answer that best answers the question. Explain your reasoning.\n"
)
system_no_input_prompt = (
"Choose the answer that best answers the question. Explain your reasoning.\n"
) )
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
Prompter for multiple choice concise Prompter for multiple choice concise
""" """
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n" system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
def match_prompt_style(self):
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter): class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
Prompter for summarize TLDR Prompter for summarize TLDR
""" """
prompt_no_input = ( system_prompt = ""
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" system_no_input_prompt = ""
)
def match_prompt_style(self):
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter: class CompletionPrompter:
@@ -128,9 +132,6 @@ class CompletionPrompter:
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
yield instruction yield instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter): class GPTeacherPrompter(AlpacaPrompter):
""" """
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
res = f"{res}{label}" res = f"{res}{label}"
yield res yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SeparatorStyle(Enum): class SeparatorStyle(Enum):
"""Different separator style.""" """Different separator style."""
@@ -289,12 +287,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
sep2=" ", sep2=" ",
) )
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
def build_prompt(self, source) -> Generator[str, None, None]: def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided # ignore the system prompt if provided
if source[0]["from"] == "system": if source[0]["from"] == "system":

View File

@@ -2,13 +2,14 @@
 import os

+from optimum.bettertransformer import BetterTransformer
 from transformers import (
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy


 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
         kwargs["model"].save_pretrained(peft_model_path)

         return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods
"""Callback to save the BetterTransformer wrapped model"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control
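The transform/reverse round trip this callback depends on looks roughly like the following (a minimal sketch against the `optimum` API used above; the model name is an arbitrary small stand-in):

```python
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")  # arbitrary small model
model = BetterTransformer.transform(model)  # swap in BetterTransformer modules for training
# ... training steps ...
model = BetterTransformer.reverse(model)    # unwrap; a wrapped model can't be saved directly
model.save_pretrained("./checkpoint")
```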

View File

@@ -1,10 +1,11 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torch
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -36,7 +37,7 @@ from axolotl.prompters import (
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path split, tokenizer, cfg, default_dataset_prepared_path
) -> DatasetDict: ) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
@@ -48,6 +49,8 @@ def load_tokenized_prepared_datasets(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
) )
+ "|" + "|"
+ split
+ "|"
+ tokenizer_name + tokenizer_name
).encode("utf-8") ).encode("utf-8")
).hexdigest() ).hexdigest()
@@ -65,7 +68,7 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
dataset = dataset["train"] dataset = dataset[split]
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
pass pass
@@ -133,8 +136,8 @@ def load_tokenized_prepared_datasets(
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if d.shards:
if "train" in ds: if split in ds:
ds = ds.shuffle(seed=seed)["train"].shard( ds = ds.shuffle(seed=seed)[split].shard(
num_shards=d.shards, index=0 num_shards=d.shards, index=0
) )
else: else:
@@ -143,8 +146,8 @@ def load_tokenized_prepared_datasets(
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds: if split in ds:
ds = ds["train"] ds = ds[split]
if ds_strategy := load(d.type, tokenizer, cfg): if ds_strategy := load(d.type, tokenizer, cfg):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
@@ -318,7 +321,6 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec except Exception: # pylint: disable=broad-except # nosec
pass pass
@@ -338,28 +340,37 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
else: else:
dataset = load_tokenized_prepared_datasets( dataset_train = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path "train", tokenizer, cfg, default_dataset_prepared_path
) )
dataset_test = load_tokenized_prepared_datasets(
"test", tokenizer, cfg, default_dataset_prepared_path
)
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.seed: if cfg.seed:
dataset = dataset.shuffle(seed=cfg.seed) dataset = dataset.shuffle(seed=cfg.seed)
constant_len_dataset = ConstantLengthDataset( constant_len_dataset_train = ConstantLengthDataset(
tokenizer, tokenizer,
[dataset], [dataset["train"]],
seq_length=max_packed_sequence_len,
)
constant_len_dataset_test = ConstantLengthDataset(
tokenizer,
[dataset["test"]],
seq_length=max_packed_sequence_len, seq_length=max_packed_sequence_len,
) )
logging.info( logging.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}" f"packing master dataset to len: {cfg.max_packed_sequence_len}"
) )
dataset = Dataset.from_list(list(constant_len_dataset)) dataset_train = Dataset.from_list(list(constant_len_dataset_train))
dataset_test = Dataset.from_list(list(constant_len_dataset_test))
# filter out bad data # filter out bad data
dataset = Dataset.from_list( dataset_train = Dataset.from_list(
[ [
d d
for d in dataset for d in dataset_train
if len(d["input_ids"]) < cfg.sequence_len if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0 and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"]) and len(d["input_ids"]) == len(d["attention_mask"])
@@ -367,6 +378,19 @@ def load_prepare_datasets(
] ]
) )
# filter out bad data
dataset_test = Dataset.from_list(
[
d
for d in dataset_test
if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.local_rank == 0: if cfg.local_rank == 0:
logging.info( logging.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}" f"Saving packed prepared dataset to disk... {prepared_ds_path}"
@@ -381,9 +405,14 @@ def load_prepare_datasets(
private=True, private=True,
) )
else: else:
# dataset_train = load_tokenized_prepared_datasets(
dataset = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path "train", tokenizer, cfg, default_dataset_prepared_path
) )
# dataset_test = load_tokenized_prepared_datasets(
# "test", tokenizer, cfg, default_dataset_prepared_path
# )
# dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info( logging.info(
@@ -394,8 +423,130 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx, index=cfg.dataset_shard_idx,
) )
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) if cfg.val_set_size:
train_dataset = dataset["train"] dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
eval_dataset = dataset["test"] train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif "train" in dataset:
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples):
res = tokenizer(
examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
input_ids[i] = torch.cat(
(
input_ids[i],
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
logging.debug(len(ret["input_ids"]))
return ret
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
return dataset
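A hedged sketch of how this new streaming path is driven, mirroring the `scripts/finetune.py` hunk earlier (the corpus path is a placeholder; note the `map` call above also removes a `meta` column, so the source dataset is expected to have one):

```python
from transformers import AutoTokenizer

from axolotl.utils.data import load_pretraining_dataset

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")  # placeholder tokenizer
train_dataset = load_pretraining_dataset(
    "my-org/my-pretraining-corpus",  # hypothetical streaming dataset with text/meta columns
    tokenizer,
    max_tokens=2048,
    seed=42,
)
train_dataset = train_dataset.with_format("torch")  # as done in finetune.py for the Trainer
```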

View File

@@ -10,13 +10,15 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import PreTrainedModel  # noqa: F401
+from optimum.bettertransformer import BetterTransformer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
     LlamaConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
 )

 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -32,15 +34,20 @@ def load_tokenizer(
     tokenizer_type,
     cfg,
 ):
+    use_fast = True  # this is the default
+    if cfg.tokenizer_use_fast is not None:
+        use_fast = cfg.tokenizer_use_fast
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
         )

     logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -70,7 +77,7 @@ def load_tokenizer(
 def load_model(
     base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 ):
-    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
@@ -121,9 +128,9 @@ def load_model(
         logging.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

-    if cfg.bf16:
+    if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16:
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         torch_dtype = torch.float16
     else:
         torch_dtype = torch.float32
@@ -251,11 +258,16 @@
             )
             # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
             # when training starts
-            if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
+            if (
+                hasattr(config, "max_seq_len")
+                and config.max_seq_len
+                and cfg.sequence_len > config.max_seq_len
+            ):
                 config.max_seq_len = cfg.sequence_len
                 logging.warning(f"increasing context length to {cfg.sequence_len}")
             elif (
                 hasattr(config, "max_sequence_length")
+                and config.max_sequence_length
                 and cfg.sequence_len > config.max_sequence_length
             ):
                 config.max_sequence_length = cfg.sequence_len
@@ -278,6 +290,7 @@ def load_model(
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=cfg.trust_remote_code or False,
@@ -287,6 +300,16 @@ def load_model(
         embeddings_len = math.ceil(len(tokenizer) / 32) * 32
         model.resize_token_embeddings(embeddings_len)

+        if (
+            hasattr(model.config, "max_position_embeddings")
+            and model.config.max_position_embeddings
+            and cfg.sequence_len >= model.config.max_position_embeddings
+        ):
+            logging.warning(
+                f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+            )
+            model.config.max_position_embeddings = cfg.sequence_len
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +355,9 @@ def load_model(
         logging.warning("there are no parameters that require gradient updates")
     model.config.use_cache = False

+    if cfg.flash_optimum:
+        model = BetterTransformer.transform(model)
+
     # TODO resume_from_checkpoint handling
     return model, lora_config
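In config terms, the knobs wired up in `load_tokenizer`/`load_model` here are the ones documented in the README hunks above; a brief sketch (values illustrative):

```yaml
tokenizer_use_fast: false  # forwarded to from_pretrained; defaults to true when unset
bfloat16: true             # load weights directly in bfloat16 (no AMP), vs. bf16 which uses AMP
flash_optimum: true        # wrap the loaded model with BetterTransformer
```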

View File

@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
logging.info(" ".join(colored_tokens)) logging.info(" ".join(colored_tokens))
logging.info("\n\n\n") logging.info("\n\n\n")
return " ".join(colored_tokens)

View File

@@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import InterpolatingLogScheduler from axolotl.utils.schedulers import InterpolatingLogScheduler
@@ -112,6 +115,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         # TODO search Path("./") for one
         training_arguments_kwargs["deepspeed"] = "./ds_config.json"
 
+    if cfg.adam_beta1:
+        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
+    if cfg.adam_beta2:
+        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
+    if cfg.adam_epsilon:
+        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
+    if cfg.max_grad_norm:
+        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
+
+    if cfg.push_to_hub_model_id:
+        training_arguments_kwargs["push_to_hub_model_id"] = cfg.push_to_hub_model_id
+        training_arguments_kwargs["push_to_hub"] = True
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
@@ -121,9 +137,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
-        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
+        evaluation_strategy="steps",
         save_strategy="steps" if cfg.save_steps else "epoch",
-        eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
+        eval_steps=cfg.eval_steps,
         save_steps=cfg.save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
@@ -228,6 +244,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
+    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        callbacks.append(SaveBetterTransformerModelCallback)
+
     data_collator_kwargs = {
         "padding": True,
     }
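
The trainer-side additions all follow one pattern: an optional cfg key is copied into training_arguments_kwargs only when it is set, and the accumulated kwargs are then splatted into transformers.TrainingArguments. A self-contained sketch of that pattern follows; the cfg dict, its values, and the "./out" output directory are made up for illustration, while the TrainingArguments parameter names are the real ones used in the diff.

import transformers

# hypothetical config values; axolotl reads these from its YAML config
cfg = {"adam_beta2": 0.95, "max_grad_norm": 1.0, "push_to_hub_model_id": None}

training_arguments_kwargs = {}
for key in ("adam_beta1", "adam_beta2", "adam_epsilon", "max_grad_norm"):
    if cfg.get(key):
        training_arguments_kwargs[key] = cfg[key]

# pushing intermediate checkpoints to the Hub is only enabled when a model id is set
if cfg.get("push_to_hub_model_id"):
    training_arguments_kwargs["push_to_hub_model_id"] = cfg["push_to_hub_model_id"]
    training_arguments_kwargs["push_to_hub"] = True

training_args = transformers.TrainingArguments(
    output_dir="./out",  # placeholder
    **training_arguments_kwargs,
)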

View File

@@ -2,6 +2,8 @@
 import logging
+
+import torch
 
 
 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,42 @@ def validate_config(cfg):
     ) and cfg.gradient_checkpointing:
         raise ValueError("gradient_checkpointing is not supported for MPT models")
 
+    if cfg.flash_optimum is True:
+        if cfg.adapter:
+            logging.warning(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+            )
+        if cfg.fp16 or cfg.bf16:
+            raise ValueError("AMP is not supported with BetterTransformer")
+        if cfg.float16 is not True and cfg.bloat16 is not True:
+            logging.warning(
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in float16 for BetterTransformers"
+            )
+        if int(torch.__version__.split(".")[0]) < 2:
+            logging.warning("torch>=2.0.0 required")
+            raise ValueError(
+                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
+            )
+
+    if cfg.pretraining_dataset and cfg.group_by_length:
+        logging.warning(
+            "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
+        )
+
+    if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
+        not cfg.optimizer or "adamw" not in cfg.optimizer
+    ):
+        logging.warning("adamw hyperparameters found, but no adamw optimizer set")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
     # no 8bit adamw w bf16
+
+    # GPT-NeoX
+    # evals broken when extending context len
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
+    #     attention_mask = causal_mask + attention_mask
+    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
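
Every rule added here either warns or raises, so the quickest way to see the behavior is to hand validate_config a minimal DictDefault, exactly as the new tests further down in this compare do. A short sketch, assuming axolotl is importable as in those tests and torch >= 2.0 is installed:

from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config

# mixing AMP flags with BetterTransformer raises
try:
    validate_config(DictDefault({"flash_optimum": True, "bf16": True}))
except ValueError as err:
    print(err)  # AMP is not supported with BetterTransformer

# adamw hyperparameters without an adamw optimizer only log a warning
validate_config(DictDefault({"optimizer": "adafactor", "adamw_beta1": 0.9}))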

View File

@@ -6,8 +6,16 @@ from pathlib import Path
 from transformers import AutoTokenizer
 
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompter
+from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_strategies.alpaca_w_system import (
+    InstructionWSystemPromptTokenizingStrategy,
+    SystemDataPrompter,
+)
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
+)
+from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
 
 logging.basicConfig(level="INFO")
@@ -29,7 +37,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         )
 
     def test_sharegpt_integration(self):
-        print(Path(__file__).parent)
         with open(
             Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
         ) as fin:
@@ -53,6 +60,79 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
         self.assertEqual(example[fields], tokenized_conversation[fields])
 
+    def test_no_sys_prompt(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        prompter = NoSystemPrompter()
+        # pylint: disable=duplicate-code
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "instruction": "hello cruel. lorem ipsum dolor sit amet.",
+            "output": "world!",
+        }
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(3186)
+        assert example["labels"][world_idx] == 3186
+        assert example["labels"][world_idx - 1] == -100
+
+    def test_alpaca(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        # pylint: disable=duplicate-code
+        prompter = AlpacaPrompter()
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(6324)
+        assert example["labels"][world_idx] == 6324
+        assert example["labels"][world_idx - 1] == -100
+
+
+class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
+    """
+    Test class for prompt tokenization strategies with sys prompt from the dataset
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+            }
+        )
+
+    def test_system_alpaca(self):
+        prompter = SystemDataPrompter(PromptStyle.CHAT.value)
+        strat = InstructionWSystemPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "system": "use cot",
+            "instruction": "hello!",
+            "output": "Hi! How can I help?",
+        }
+        example = strat.tokenize_prompt(sample)
+        assert example["input_ids"][0:3] == [1, 671, 20118]  # <s>use cot
+        assert example["input_ids"][3] == 11889  # USER
+
 
 if __name__ == "__main__":
     unittest.main()

View File

@@ -2,7 +2,13 @@
 import unittest
 
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
+from axolotl.prompters import (
+    AlpacaPrompter,
+    MultipleChoiceExplainPrompter,
+    PromptStyle,
+    UnpromptedPrompter,
+)
 
 
 class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res assert "### Response:" not in res
assert "USER:" in res assert "USER:" in res
assert "ASSISTANT:" in res assert "ASSISTANT:" in res
def test_system_prompt(self):
prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt_w_system(
"use cot", "tell me a joke about the following", "alpacas"
)
)
assert "use cot" in res
assert res.startswith("use cot")
assert "### Instruction:" not in res
assert "### Input:" not in res
assert "alpacas" in res
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res
class UnpromptedPrompterTest(unittest.TestCase):
"""
Test class for UnpromptedPrompter with no system prompts
"""
def test_prompt_style_w_none(self):
prompter = UnpromptedPrompter(prompt_style=None)
res = next(prompter.build_prompt("tell me a joke"))
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_instruct(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_chat(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "USER:" in res
assert "tell me a joke" in res
assert res.startswith("USER:")
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
"""
Test class for MultipleChoiceExplainPrompter
"""
def test_prompt_style_w_chat(self):
prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
assert "USER:" in res
assert "choose one" in res
assert "Choose the answer that best answers the question." in res
assert "- A\n- B\n- C" in res

tests/test_tokenizers.py (new file, +31 lines)
View File

@@ -0,0 +1,31 @@
"""
Test cases for the tokenizer loading
"""
import unittest
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers(unittest.TestCase):
"""
test class for the load_tokenizer fn
"""
def test_default_use_fast(self):
cfg = DictDefault({})
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" in tokenizer.__class__.__name__
def test_dont_use_fast(self):
cfg = DictDefault(
{
"tokenizer_use_fast": False,
}
)
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
assert "Fast" not in tokenizer.__class__.__name__
if __name__ == "__main__":
unittest.main()

View File

@@ -212,3 +212,104 @@ class ValidationTest(unittest.TestCase):
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
 
+    def test_flash_optimum(self):
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "adapter": "lora",
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "probably set bfloat16 or float16" in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "fp16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "bf16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+    def test_adamw_hyperparams(self):
+        cfg = DictDefault(
+            {
+                "optimizer": None,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "adamw hyperparameters found, but no adamw optimizer set"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+                "adamw_beta1": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "adamw hyperparameters found, but no adamw optimizer set"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adamw_bnb_8bit",
+                "adamw_beta1": 0.0001,
+                "adamw_beta2": 0.0001,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+            }
+        )
+
+        validate_config(cfg)