Merge branch 'main' into quadratic-warmup

This commit is contained in:
Wing Lian
2023-07-10 12:42:12 -04:00
committed by GitHub
29 changed files with 958 additions and 130 deletions

View File

@@ -12,6 +12,7 @@ jobs:
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: self-hosted runs-on: self-hosted
strategy: strategy:
fail-fast: false
matrix: matrix:
include: include:
- cuda: "118" - cuda: "118"
@@ -25,7 +26,7 @@ jobs:
pytorch: 2.0.0 pytorch: 2.0.0
axolotl_extras: axolotl_extras:
- cuda: "117" - cuda: "117"
cuda_version: 11.7.0 cuda_version: 11.7.1
python_version: "3.9" python_version: "3.9"
pytorch: 1.13.1 pytorch: 1.13.1
axolotl_extras: axolotl_extras:

View File

@@ -11,6 +11,7 @@ jobs:
if: github.repository_owner == 'OpenAccess-AI-Collective' if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
strategy: strategy:
fail-fast: false
matrix: matrix:
include: include:
- cuda: cu118 - cuda: cu118
@@ -29,7 +30,7 @@ jobs:
pytorch: 2.0.0 pytorch: 2.0.0
axolotl_extras: gptq axolotl_extras: gptq
- cuda: cu117 - cuda: cu117
cuda_version: 11.7.0 cuda_version: 11.7.1
python_version: "3.9" python_version: "3.9"
pytorch: 1.13.1 pytorch: 1.13.1
axolotl_extras: axolotl_extras:
@@ -84,7 +85,7 @@ jobs:
pytorch: 2.0.0 pytorch: 2.0.0
axolotl_extras: gptq axolotl_extras: gptq
- cuda: cu117 - cuda: cu117
cuda_version: 11.7.0 cuda_version: 11.7.1
python_version: "3.9" python_version: "3.9"
pytorch: 1.13.1 pytorch: 1.13.1
axolotl_extras: axolotl_extras:

View File

@@ -7,6 +7,7 @@ jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false
matrix: matrix:
python_version: ["3.9", "3.10"] python_version: ["3.9", "3.10"]
timeout-minutes: 10 timeout-minutes: 10

View File

@@ -1,5 +1,5 @@
default_language_version: default_language_version:
python: python3.9 python: python3
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks

View File

@@ -138,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"instruction": "...", "input": "...", "output": "..."} {"instruction": "...", "input": "...", "output": "..."}
``` ```
- `sharegpt`: conversations - `sharegpt:chat`: conversations
```json ```json
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
@@ -195,6 +195,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"message_1": "...", "message_2": "..."} {"message_1": "...", "message_2": "..."}
``` ```
- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
```json
{"system_prompt": "...", "question": "...", "response": "..."}
```
- `context_qa`: in context question answering from an article - `context_qa`: in context question answering from an article
```json ```json
{"article": "...", "question": "...", "answer": "..."} {"article": "...", "question": "...", "answer": "..."}
@@ -233,7 +237,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts #### How to add custom prompts
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example. 1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type. 2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
Optionally, download some datasets, see [data/README.md](data/README.md) Optionally, download some datasets, see [data/README.md](data/README.md)
@@ -251,10 +255,18 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
- dataset - dataset
```yaml ```yaml
sequence_len: 2048 # max token length for prompt
# huggingface repo
datasets: datasets:
- path: vicgalle/alpaca-gpt4 # local or huggingface repo - path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
# local
datasets:
- path: json
data_files: data.jsonl # or json
type: alpaca # format from earlier type: alpaca # format from earlier
sequence_len: 2048 # max token length / prompt
``` ```
- loading - loading
@@ -264,6 +276,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
bf16: true # require >=ampere bf16: true # require >=ampere
fp16: true fp16: true
tf32: true # require >=ampere tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
``` ```
Note: Repo does not do 4-bit quantization. Note: Repo does not do 4-bit quantization.
@@ -300,6 +314,8 @@ model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
# Trust remote code for untrusted source # Trust remote code for untrusted source
trust_remote_code: trust_remote_code:
# value passed as `use_fast` to the tokenizer's `from_pretrained`; defaults to True
tokenizer_use_fast:
# whether you are training a 4-bit GPTQ quantized model # whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
@@ -320,10 +336,10 @@ tf32: true # require >=ampere
# a list of one or more datasets to finetune the model with # a list of one or more datasets to finetune the model with
datasets: datasets:
# this can be either a hf dataset, or relative path # hf dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format OR format:prompt_style (chat/instruct) type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
data_files: # path to source data files data_files: # path to source data files
shards: # number of shards to split data into shards: # number of shards to split data into
@@ -332,6 +348,8 @@ datasets:
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub # push prepared dataset to hub
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub` # required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
@@ -420,7 +438,15 @@ log_sweep_max_lr:
optimizer: optimizer:
# specify weight decay # specify weight decay
weight_decay: weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_epsilon:
# Gradient clipping max norm
max_grad_norm:
# whether to use BetterTransformer (from Hugging Face optimum)
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers: # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention: xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention: # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
@@ -500,16 +526,16 @@ Pass the appropriate flag to the train command:
- Pretrained LORA: - Pretrained LORA:
```bash ```bash
--inference --lora_model_dir ./completed-model --inference --lora_model_dir="./lora-output-dir"
``` ```
- Full weights finetune: - Full weights finetune:
```bash ```bash
--inference --base_model ./completed-model --inference --base_model="./completed-model"
``` ```
- Full weights finetune w/ a prompt from a text file: - Full weights finetune w/ a prompt from a text file:
```bash ```bash
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \ cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
--base_model ./completed-model --inference --prompter=None --load_in_8bit=True --base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
``` ```
### Merge LORA to base ### Merge LORA to base
@@ -520,6 +546,12 @@ Add below flag to train command above
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
``` ```
If you run out of CUDA memory, you can try to merge in system RAM with
```bash
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
```
## Common Errors 🧰 ## Common Errors 🧰
> Cuda out of memory > Cuda out of memory
@@ -552,6 +584,16 @@ Building something cool with Axolotl? Consider adding a badge to your model card
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl) [<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
## Community Showcase
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
PocketDoc Labs
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
## Contributing 🤝 ## Contributing 🤝
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new). Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).

View File

@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
## Convert the JSON data files to JSONL. ## Convert the JSON data files to JSONL.
```shell ```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
``` ```
--- ---

View File

@@ -77,7 +77,7 @@ FROM base-builder
RUN python3 -m pip uninstall -y apex RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex RUN git clone https://github.com/NVIDIA/apex
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners # `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check . RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
@@ -97,4 +97,4 @@ RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo RUN git lfs install --skip-repo
RUN pip3 install awscli && \ RUN pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -26,17 +26,18 @@ wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./openllama-out output_dir: ./openllama-out
batch_size: 16 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 1
num_epochs: 3 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.00001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
float16: true
bf16: false bf16: false
fp16: true fp16: false
tf32: false tf32: false
gradient_checkpointing: true gradient_checkpointing: true
early_stopping_patience: early_stopping_patience:
@@ -52,7 +53,7 @@ eval_steps: 50
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.1
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:

View File

@@ -0,0 +1,9 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️

View File

@@ -0,0 +1,49 @@
# Example axolotl config for EleutherAI/pythia-12b-deduped on alpaca-gpt4.

# model and tokenizer
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch*  # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto

# datasets: `path` and `type` belong to the same list entry
datasets:
  - path: vicgalle/alpaca-gpt4
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05

# adapter settings; `adapter` is left unset here
# NOTE(review): with no adapter selected the lora_* values below are
# presumably unused - confirm before relying on them
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true  # pythia/GPTNeoX lora specific

# logging (Weights & Biases), left unset
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:

# training hyperparameters
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false

# precision: `float16` (not `fp16`) is set, i.e. half precision without AMP
bf16: false
fp16: false
float16: true
tf32: true

# BetterTransformer via optimum
flash_optimum: true

early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -1,7 +1,7 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1 base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1 base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: GPTNeoXTokenizer tokenizer_type: AutoTokenizer
trust_remote_code: trust_remote_code:
load_in_8bit: false load_in_8bit: false
datasets: datasets:

View File

@@ -11,6 +11,7 @@ sentencepiece
wandb wandb
einops einops
xformers xformers
optimum
# qlora things # qlora things
bert-score==0.3.13 bert-score==0.3.13
evaluate==0.4.0 evaluate==0.4.0

View File

@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import yaml import yaml
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config from axolotl.utils.validation import validate_config
@@ -63,7 +64,7 @@ def get_multi_line_input() -> Optional[str]:
return instruction return instruction
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"} default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items(): for token, symbol in default_tokens.items():
@@ -217,9 +218,20 @@ def train(
if ( if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these ): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets( if not cfg.pretraining_dataset:
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH train_dataset, eval_dataset = load_prepare_datasets(
) tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...") logging.info("check_dataset_labels...")
@@ -257,13 +269,13 @@ def train(
if cfg.inference: if cfg.inference:
logging.info("calling do_inference function") logging.info("calling do_inference function")
inf_kwargs: Dict[str, Any] = {} prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs: if "prompter" in kwargs:
if kwargs["prompter"] == "None": if kwargs["prompter"] == "None":
inf_kwargs["prompter"] = None prompter = None
else: else:
inf_kwargs["prompter"] = kwargs["prompter"] prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, **inf_kwargs) do_inference(cfg, model, tokenizer, prompter=prompter)
return return
if "shard" in kwargs: if "shard" in kwargs:
@@ -285,12 +297,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal( signal.signal(
signal.SIGINT, signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
) )
logging.info("Starting trainer...") logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
if not Path(cfg.output_dir).is_dir(): if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True) os.makedirs(cfg.output_dir, exist_ok=True)
trainer.train(resume_from_checkpoint=resume_from_checkpoint) if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0: if cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
buffer_len = 0 buffer_len = 0
if example: if example:
# FIXME
# just going to drop data points that are too long # just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length: if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"] input_ids = example["input_ids"]

View File

@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy, InstructionPromptTokenizingStrategy,
) )
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg): def load(tokenizer, cfg):
@@ -20,11 +20,38 @@ def load(tokenizer, cfg):
class AlpacaConcisePrompter(AlpacaPrompter): class AlpacaConcisePrompter(AlpacaPrompter):
""" """
Alpaca Prompter extending the system prompt to ask for concise answers Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
""" """
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n" system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n" system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
class AlpacaChatPrompter(AlpacaPrompter):
    """
    Alpaca Chat Prompter extending the system prompt for chat-instruct answers
    """

    # System prompt for examples that include an input/context field.
    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
    # System prompt for examples without an input field.
    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"

    def __init__(self):  # pylint: disable=super-init-not-called
        # The style is pinned to CHAT rather than taken as a constructor
        # argument; match_prompt_style() then configures the prompter for it.
        self.prompt_style = PromptStyle.CHAT.value
        self.match_prompt_style()
class NoSystemPrompter(AlpacaPrompter):
    """
    Null Prompter with no system prompts
    """

    # Both system prompts are empty, so nothing is prepended to a turn.
    system_prompt = ""
    system_no_input_prompt = ""
    # Turn templates: the instruction (and input, when present) followed by a
    # single space, with no role or section markers.
    turn_format = "{instruction} {input} "
    turn_no_input_format = "{instruction} "

    def __init__(self):  # pylint: disable=super-init-not-called
        # Deliberately skips the parent constructor; the class-level
        # templates above are used as-is.
        pass
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +91,7 @@ def load_concise(tokenizer, cfg):
def load_qa(tokenizer, cfg): def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy( return AlpacaQAPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value), AlpacaChatPrompter(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
@@ -73,7 +100,16 @@ def load_qa(tokenizer, cfg):
def load_camel_ai(tokenizer, cfg): def load_camel_ai(tokenizer, cfg):
return CamelAIPromptTokenizingStrategy( return CamelAIPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value), AlpacaChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_no_prompt(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
UnpromptedPrompter(PromptStyle.CHAT.value),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -1,7 +1,7 @@
"""Module loading the AlpacaInstructPromptTokenizingStrategy class""" """Module loading the AlpacaInstructPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg): def load(tokenizer, cfg):
@@ -11,3 +11,12 @@ def load(tokenizer, cfg):
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
def load_no_prompt(tokenizer, cfg):
    """
    Build an Alpaca tokenizing strategy around an UnpromptedPrompter
    configured with the instruct prompt style.
    """
    prompter = UnpromptedPrompter(PromptStyle.INSTRUCT.value)
    return AlpacaPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )

View File

@@ -0,0 +1,120 @@
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction prompts whose rows also carry a
    system prompt.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
        """
        Pull (instruction, input, output, system) out of a dataset row.

        "input" is optional and falls back to an empty string; the other
        three keys are read directly from the row.
        """
        instruction = prompt["instruction"]
        if "input" in prompt:
            input_text = prompt["input"]
        else:
            input_text = ""
        return instruction, input_text, prompt["output"], prompt["system"]

    def tokenize_prompt(self, prompt):
        """
        Tokenize the user-facing prompt and the response separately, then
        concatenate them into a single example.
        """
        # pylint: disable=duplicate-code
        instruction, input_text, response, system = self.parse_instruction_fields(
            prompt
        )
        prompt_builder = self.prompter.build_prompt_w_system(
            system, instruction, input_text
        )
        user_prompt = next(iter(prompt_builder))

        # The prompt part gets no EOS; the response appended below closes it.
        tokenized = self._tokenize(user_prompt, add_eos_token=False)
        if not self.train_on_inputs:
            # Replace the prompt's labels with -100 for every prompt token.
            # TODO this could be sped up using numpy array slicing
            tokenized["labels"] = [-100] * len(tokenized["input_ids"])

        tokenized_response = self._tokenize(
            response, strip_bos_token=True, add_eos_token=True
        )
        # Extend in the same order as before: ids, mask, then labels (which
        # take the response's input ids).
        for target_key, source_key in (
            ("input_ids", "input_ids"),
            ("attention_mask", "attention_mask"),
            ("labels", "input_ids"),
        ):
            tokenized[target_key] += tokenized_response[source_key]

        return tokenized
class SystemDataPrompter(AlpacaPrompter):
    """
    Alpaca Style Prompter that uses system prompts from the dataset
    """

    def build_prompt_w_system(
        self,
        system: str,
        instruction: str,
        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
        output: Union[None, str] = None,
    ) -> Generator[str, None, None]:
        """
        Yield one prompt string: the row's system text followed directly by
        the formatted turn, plus the output when one is given.
        """
        # Pick the turn template depending on whether an input was supplied.
        if input:
            turn = self.turn_format.format(instruction=instruction, input=input)
        else:
            turn = self.turn_no_input_format.format(instruction=instruction)
        prompt_text = system + turn
        # A provided label (response/output) is appended verbatim.
        yield f"{prompt_text}{output}" if output else prompt_text
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
    """
    Tokenizing strategy for OpenOrca datasets
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
        """
        Map a row's "question", "response" and "system_prompt" keys onto the
        (instruction, input, output, system) tuple; input is always empty.
        """
        question = prompt["question"]
        answer = prompt["response"]
        system = prompt["system_prompt"]
        return question, "", answer, system
def load(tokenizer, cfg):
    """Default loader for this module; delegates to load_chat."""
    return load_chat(tokenizer, cfg)
def load_instruct(tokenizer, cfg):
    """
    Build the system-prompt-aware tokenizing strategy with an
    instruct-style SystemDataPrompter.
    """
    prompter = SystemDataPrompter(PromptStyle.INSTRUCT.value)
    return InstructionWSystemPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
def load_chat(tokenizer, cfg):
    """
    Build the system-prompt-aware tokenizing strategy with a chat-style
    SystemDataPrompter.
    """
    prompter = SystemDataPrompter(PromptStyle.CHAT.value)
    return InstructionWSystemPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
def load_open_orca(tokenizer, cfg):
    """
    Build the OpenOrca tokenizing strategy with an instruct-style
    SystemDataPrompter.
    """
    prompter = SystemDataPrompter(PromptStyle.INSTRUCT.value)
    return OpenOrcaPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )

View File

@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts. Tokenizing strategy for instruction-based prompts.
""" """
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: def parse_instruction_fields(
self, prompt
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
raise NotImplementedError raise NotImplementedError
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
@@ -96,25 +98,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
response, response,
) = self.parse_instruction_fields(prompt) ) = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, response) user_prompt = next(
tokenized_full_prompt = self._tokenize(full_prompt) iter(
if not self.train_on_inputs: self.prompter.build_prompt(
user_prompt = next( instruction,
iter( input,
self.prompter.build_prompt(
instruction,
input,
)
) )
) )
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) )
user_prompt_len = len(tokenized_user_prompt["input_ids"]) tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [ tokenized_prompt["labels"] = [-100] * user_prompt_len
-100 tokenized_res_prompt = self._tokenize(
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_full_prompt return tokenized_prompt
def _build_full_prompt( def _build_full_prompt(
self, instruction, input, response # pylint: disable=redefined-builtin self, instruction, input, response # pylint: disable=redefined-builtin
@@ -436,7 +440,7 @@ def parse_tokenized_to_result(
result: Dict[str, List[int]], result: Dict[str, List[int]],
current_len: int, current_len: int,
res: Dict[str, List[int]], res: Dict[str, List[int]],
labels: list[int], labels: List[int],
pad_token_id: Union[int, None] = None, pad_token_id: Union[int, None] = None,
) -> Tuple[Dict[str, List[int]], int]: ) -> Tuple[Dict[str, List[int]], int]:
""" """

View File

@@ -24,6 +24,8 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
def match_prompt_style(self): def match_prompt_style(self):
if self.prompt_style == PromptStyle.INSTRUCT.value: if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = ( self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.system_prompt self.turn_no_input_format = (
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" "### Instruction:\n{instruction}\n\n### Response:\n"
) )
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = ( self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:"
def build_prompt( def build_prompt(
self, self,
@@ -59,16 +51,17 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input: if input:
res = self.prompt_input.format(instruction=instruction, input=input) res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else: else:
res = self.prompt_no_input.format(instruction=instruction) res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output: if output:
res = f"{res}{output}" res = f"{res}{output}"
yield res yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class UnpromptedPrompter(AlpacaPrompter): class UnpromptedPrompter(AlpacaPrompter):
""" """
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
""" """
system_prompt = ( system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning." "Choose the answer that best answers the question. Explain your reasoning.\n"
)
system_no_input_prompt = (
"Choose the answer that best answers the question. Explain your reasoning.\n"
) )
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
Prompter for multiple choice concise Prompter for multiple choice concise
""" """
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n" system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
def match_prompt_style(self):
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter): class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
Prompter for summarize TLDR Prompter for summarize TLDR
""" """
prompt_no_input = ( system_prompt = ""
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" system_no_input_prompt = ""
)
def match_prompt_style(self):
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter: class CompletionPrompter:
@@ -128,9 +132,6 @@ class CompletionPrompter:
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
yield instruction yield instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter): class GPTeacherPrompter(AlpacaPrompter):
""" """
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
res = f"{res}{label}" res = f"{res}{label}"
yield res yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SeparatorStyle(Enum): class SeparatorStyle(Enum):
"""Different separator style.""" """Different separator style."""
@@ -289,12 +287,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
sep2=" ", sep2=" ",
) )
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
def build_prompt(self, source) -> Generator[str, None, None]: def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided # ignore the system prompt if provided
if source[0]["from"] == "system": if source[0]["from"] == "system":

View File

@@ -2,13 +2,14 @@
import os import os
from optimum.bettertransformer import BetterTransformer
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
TrainerControl, TrainerControl,
TrainerState, TrainerState,
TrainingArguments, TrainingArguments,
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
kwargs["model"].save_pretrained(peft_model_path) kwargs["model"].save_pretrained(peft_model_path)
return control return control
class SaveBetterTransformerModelCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods
    """Checkpoint-saving callback for a BetterTransformer wrapped model.

    The model is converted back with ``BetterTransformer.reverse`` and written
    from this callback; the trainer's own save is then switched off.
    """

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Write a checkpoint for ``kwargs["model"]`` when a save is due.

        A save is requested here when the save strategy is step based and the
        current global step is a positive multiple of ``args.save_steps``; a
        ``control.should_save`` flag that is already set is honoured as well.
        Returns ``control``.
        """
        step = state.global_step
        on_save_boundary = (
            args.save_strategy == IntervalStrategy.STEPS
            and args.save_steps > 0
            and step % args.save_steps == 0
        )
        if on_save_boundary:
            control.should_save = True

        if not control.should_save:
            return control

        target_dir = os.path.join(
            args.output_dir,
            f"{PREFIX_CHECKPOINT_DIR}-{step}",
        )
        BetterTransformer.reverse(kwargs["model"]).save_pretrained(target_dir)
        # FIXME - need to cleanup old checkpoints
        # The checkpoint was written above, so stop the trainer loop from
        # saving as well: it raises when asked to save a BetterTransformer
        # wrapped model.
        control.should_save = False
        return control

View File

@@ -1,10 +1,11 @@
"""Module containing data utilities""" """Module containing data utilities"""
import functools
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torch
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -101,13 +102,26 @@ def load_tokenized_prepared_datasets(
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
if Path(d.path).exists(): local_path = Path(d.path)
ds = load_dataset( if local_path.exists():
"json", if local_path.is_dir():
data_files=d.path, ds = load_dataset(
streaming=False, d.path,
split=None, data_files=d.data_files,
) streaming=False,
split=None,
)
elif local_path.is_file():
ds = load_dataset(
"json",
data_files=d.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub: elif ds_from_hub:
if d.data_files: if d.data_files:
ds = load_dataset( ds = load_dataset(
@@ -394,8 +408,127 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx, index=cfg.dataset_shard_idx,
) )
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) if cfg.val_set_size:
train_dataset = dataset["train"] dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
eval_dataset = dataset["test"] train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples):
    """Tokenize ``examples["text"]`` and pack the results into fixed-size rows.

    Each text is tokenized with truncation to ``max_tokens - 2`` and then gets
    ``eos_token_id`` and ``pad_token_id`` appended (attention mask ``1, 0``).
    Sequences are concatenated greedily into rows; a row that cannot take the
    next sequence is filled up to ``max_tokens`` with ``pad_token_id``
    (attention mask ``0``) and emitted.

    Returns a dict of lists of lists with keys ``input_ids``, ``labels`` and
    ``attention_mask``; ``labels`` holds the same values as ``input_ids``.
    """
    encoded = tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_tokens - 2,
        add_special_tokens=True,
    )

    # Every sequence is terminated with EOS followed by one PAD token.
    id_tensors = [torch.tensor(seq) for seq in encoded["input_ids"]]
    mask_tensors = [torch.tensor(seq) for seq in encoded["attention_mask"]]
    for pos in range(len(id_tensors)):
        id_tensors[pos] = torch.cat(
            (
                id_tensors[pos],
                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
            ),
            dim=0,
        )
        mask_tensors[pos] = torch.cat(
            (mask_tensors[pos], torch.tensor([1, 0])), dim=0
        )

    def _fill(buffer, value):
        # Extend `buffer` with `value` until it holds max_tokens elements.
        filler = torch.full(
            (max_tokens - buffer.numel(),), value, dtype=torch.long
        )
        return torch.cat((buffer, filler), dim=0)

    def _fresh():
        return torch.tensor([], dtype=torch.long)

    packed_ids = []
    packed_masks = []
    row_ids = _fresh()
    row_mask = _fresh()

    for seq_ids, seq_mask in zip(id_tensors, mask_tensors):
        used = row_ids.numel()
        row_is_full = used == max_tokens
        if not row_is_full and used + seq_ids.numel() <= max_tokens:
            # Still room in the current row: keep packing.
            row_ids = torch.cat((row_ids, seq_ids), dim=0)
            row_mask = torch.cat((row_mask, seq_mask), dim=0)
            continue

        if not row_is_full:
            # Row is short but cannot take this sequence: pad it out first.
            row_ids = _fill(row_ids, tokenizer.pad_token_id)
            row_mask = _fill(row_mask, 0)
        packed_ids.append(row_ids)
        packed_masks.append(row_mask)

        # Start the next row with the sequence that did not fit.
        row_ids = torch.cat((_fresh(), seq_ids), dim=0)
        row_mask = torch.cat((_fresh(), seq_mask), dim=0)

    if row_ids.numel() > 0:  # leftover tokens form a final row
        if row_ids.numel() < max_tokens:  # make all rows equal in size
            row_ids = _fill(row_ids, tokenizer.pad_token_id)
            row_mask = _fill(row_mask, 0)
        packed_ids.append(row_ids)
        packed_masks.append(row_mask)

    ret = {
        "input_ids": [row.tolist() for row in packed_ids],
        "labels": [row.tolist() for row in packed_ids],
        "attention_mask": [row.tolist() for row in packed_masks],
    }
    logging.debug(len(ret["input_ids"]))
    return ret
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
    """Stream the ``train`` split of ``path``, shuffle it and encode it lazily.

    The stream is shuffled with ``seed`` through a 10,000 element buffer and
    mapped in batches through ``encode_pretraining`` bound to ``tokenizer`` and
    ``max_tokens``. The ``text`` and ``meta`` columns are dropped by the map.
    """
    streamed = load_dataset(path, streaming=True, split="train")
    shuffled = streamed.shuffle(seed=seed, buffer_size=10_000)
    # TODO dynamically figure out which columns/features to remove
    return shuffled.map(
        functools.partial(encode_pretraining, tokenizer, max_tokens),
        batched=True,
        remove_columns=["text", "meta"],
    )

View File

@@ -10,13 +10,15 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from transformers import PreTrainedModel # noqa: F401 from optimum.bettertransformer import BetterTransformer
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
BitsAndBytesConfig, BitsAndBytesConfig,
LlamaConfig, LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
) )
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -32,15 +34,20 @@ def load_tokenizer(
tokenizer_type, tokenizer_type,
cfg, cfg,
): ):
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if tokenizer_type: if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained( tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_config, tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
) )
else: else:
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config, tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
) )
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -70,7 +77,7 @@ def load_tokenizer(
def load_model( def load_model(
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora" base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
): ):
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model from a base model and a model type. Load a model from a base model and a model type.
""" """
@@ -121,9 +128,9 @@ def load_model(
logging.info("patching with xpos rope") logging.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope() replace_llama_rope_with_xpos_rope()
if cfg.bf16: if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16: elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16 torch_dtype = torch.float16
else: else:
torch_dtype = torch.float32 torch_dtype = torch.float32
@@ -195,7 +202,7 @@ def load_model(
else True, else True,
) )
load_in_8bit = False load_in_8bit = False
elif cfg.is_llama_derived_model: elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config = LlamaConfig.from_pretrained(base_model_config) config = LlamaConfig.from_pretrained(base_model_config)
@@ -234,7 +241,7 @@ def load_model(
# device=cfg.device, # device=cfg.device,
# ) # )
# model.train() # sets to train instead of eval mode # model.train() # sets to train instead of eval mode
elif model_type: elif model_type and not cfg.trust_remote_code:
model = getattr(transformers, model_type).from_pretrained( model = getattr(transformers, model_type).from_pretrained(
base_model, base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
@@ -251,11 +258,16 @@ def load_model(
) )
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts # when training starts
if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len: if (
hasattr(config, "max_seq_len")
and config.max_seq_len
and cfg.sequence_len > config.max_seq_len
):
config.max_seq_len = cfg.sequence_len config.max_seq_len = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}") logging.warning(f"increasing context length to {cfg.sequence_len}")
elif ( elif (
hasattr(config, "max_sequence_length") hasattr(config, "max_sequence_length")
and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length and cfg.sequence_len > config.max_sequence_length
): ):
config.max_sequence_length = cfg.sequence_len config.max_sequence_length = cfg.sequence_len
@@ -278,6 +290,7 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map, device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
@@ -287,6 +300,16 @@ def load_model(
embeddings_len = math.ceil(len(tokenizer) / 32) * 32 embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len) model.resize_token_embeddings(embeddings_len)
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
):
logging.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
if not cfg.gptq and ( if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +355,9 @@ def load_model(
logging.warning("there are no parameters that require gradient updates") logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, lora_config

View File

@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
logging.info(" ".join(colored_tokens)) logging.info(" ".join(colored_tokens))
logging.info("\n\n\n") logging.info("\n\n\n")
return " ".join(colored_tokens)

View File

@@ -17,7 +17,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
InterpolatingLogScheduler, InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup, get_cosine_schedule_with_quadratic_warmup,
@@ -166,6 +169,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
# TODO search Path("./") for one # TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json" training_arguments_kwargs["deepspeed"] = "./ds_config.json"
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
if cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_args = AxolotlTrainingArguments( training_args = AxolotlTrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size, per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size per_device_eval_batch_size=cfg.eval_batch_size
@@ -282,6 +298,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
]: # only save in rank 0 ]: # only save in rank 0
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, "padding": True,
} }

View File

@@ -2,6 +2,8 @@
import logging import logging
import torch
def validate_config(cfg): def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size: if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,47 @@ def validate_config(cfg):
) and cfg.gradient_checkpointing: ) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models") raise ValueError("gradient_checkpointing is not supported for MPT models")
# Sanity checks that only apply when BetterTransformer (flash_optimum) is on.
if cfg.flash_optimum is True:
    if cfg.adapter:
        logging.warning(
            "BetterTransformers probably doesn't work with PEFT adapters"
        )
    # fp16 / bf16 enable mixed precision (AMP), which is rejected outright.
    if cfg.fp16 or cfg.bf16:
        raise ValueError("AMP is not supported with BetterTransformer")
    # The option is spelled `bfloat16` (it is the key read when the load dtype
    # is chosen); the previous `bloat16` spelling never matched, so this
    # warning fired even when bfloat16 was enabled.
    if cfg.float16 is not True and cfg.bfloat16 is not True:
        logging.warning(
            "You should probably set bfloat16 or float16 to true to "
            "load the model in float16 for BetterTransformers"
        )
    # Only the major version component of torch is inspected.
    if int(torch.__version__.split(".")[0]) < 2:
        logging.warning("torch>=2.0.0 required")
        raise ValueError(
            f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
        )
if cfg.pretraining_dataset and cfg.group_by_length:
logging.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
if cfg.push_to_hub_model_id:
raise ValueError(
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adamw w bf16 # no 8bit adamw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -6,8 +6,16 @@ from pathlib import Path
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompters import ShareGPTPrompter from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
SystemDataPrompter,
)
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
logging.basicConfig(level="INFO") logging.basicConfig(level="INFO")
@@ -29,7 +37,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) )
def test_sharegpt_integration(self): def test_sharegpt_integration(self):
print(Path(__file__).parent)
with open( with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin: ) as fin:
@@ -53,6 +60,79 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields]) self.assertEqual(example[fields], tokenized_conversation[fields])
def test_no_sys_prompt(self):
    """
    checks label masking at the user/assistant boundary without a system prompt
    """
    # pylint: disable=duplicate-code
    strat = AlpacaPromptTokenizingStrategy(
        NoSystemPrompter(),
        self.tokenizer,
        False,
        2048,
    )
    example = strat.tokenize_prompt(
        {
            "instruction": "hello cruel. lorem ipsum dolor sit amet.",
            "output": "world!",
        }
    )
    # token id 3186 is looked up in the input ids; its label must be kept
    # while the label just before it must be masked with -100
    world_idx = example["input_ids"].index(3186)
    labels = example["labels"]
    assert labels[world_idx] == 3186
    assert labels[world_idx - 1] == -100
def test_alpaca(self):
    """
    checks label masking at the user/assistant boundary for the default alpaca prompter
    """
    # pylint: disable=duplicate-code
    strat = AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(),
        self.tokenizer,
        False,
        2048,
    )
    example = strat.tokenize_prompt(
        {"instruction": "hello!", "output": "Hi! How can I help?"}
    )
    # token id 6324 is looked up in the input ids; its label must be kept
    # while the label just before it must be masked with -100
    world_idx = example["input_ids"].index(6324)
    labels = example["labels"]
    assert labels[world_idx] == 6324
    assert labels[world_idx - 1] == -100
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
    """
    Tokenization tests for prompts whose system prompt comes from the dataset row
    """

    def setUp(self) -> None:
        # pylint: disable=duplicate-code
        special_tokens = {
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
        }
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(special_tokens)

    def test_system_alpaca(self):
        """the row's system text leads the token ids, followed by the USER turn"""
        strat = InstructionWSystemPromptTokenizingStrategy(
            SystemDataPrompter(PromptStyle.CHAT.value),
            self.tokenizer,
            False,
            2048,
        )
        example = strat.tokenize_prompt(
            {
                "system": "use cot",
                "instruction": "hello!",
                "output": "Hi! How can I help?",
            }
        )
        input_ids = example["input_ids"]
        assert input_ids[0:3] == [1, 671, 20118]  # <s>use cot
        assert input_ids[3] == 11889  # USER
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -2,7 +2,13 @@
import unittest import unittest
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
from axolotl.prompters import (
AlpacaPrompter,
MultipleChoiceExplainPrompter,
PromptStyle,
UnpromptedPrompter,
)
class AlpacaPrompterTest(unittest.TestCase): class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res assert "### Response:" not in res
assert "USER:" in res assert "USER:" in res
assert "ASSISTANT:" in res assert "ASSISTANT:" in res
def test_system_prompt(self):
    """chat-style prompt built with a system message from the data"""
    prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
    prompt_gen = prompter.build_prompt_w_system(
        "use cot", "tell me a joke about the following", "alpacas"
    )
    res = next(prompt_gen)

    # the system text is present and leads the prompt
    assert "use cot" in res
    assert res.startswith("use cot")
    # none of the instruct-style markers may appear in chat style
    assert "### Instruction:" not in res
    assert "### Input:" not in res
    assert "alpacas" in res
    assert "### Response:" not in res
    # chat-style role markers
    for role_marker in ("USER:", "ASSISTANT:"):
        assert role_marker in res
class UnpromptedPrompterTest(unittest.TestCase):
    """
    Tests for UnpromptedPrompter: the prompt must begin directly with the turn text
    """

    def test_prompt_style_w_none(self):
        prompt_gen = UnpromptedPrompter(prompt_style=None).build_prompt(
            "tell me a joke"
        )
        res = next(prompt_gen)
        assert "### Instruction:" in res
        assert "tell me a joke" in res
        assert res.startswith("###")

    def test_prompt_style_w_instruct(self):
        prompt_gen = UnpromptedPrompter(
            prompt_style=PromptStyle.INSTRUCT.value
        ).build_prompt("tell me a joke about the following", "alpacas")
        res = next(prompt_gen)
        assert "### Instruction:" in res
        assert "tell me a joke" in res
        assert res.startswith("###")

    def test_prompt_style_w_chat(self):
        prompt_gen = UnpromptedPrompter(
            prompt_style=PromptStyle.CHAT.value
        ).build_prompt("tell me a joke about the following", "alpacas")
        res = next(prompt_gen)
        assert "USER:" in res
        assert "tell me a joke" in res
        assert res.startswith("USER:")
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
    """
    Tests for MultipleChoiceExplainPrompter
    """

    def test_prompt_style_w_chat(self):
        """the chat-style prompt carries the system text, the question and the choices"""
        prompt_gen = MultipleChoiceExplainPrompter(
            prompt_style=PromptStyle.CHAT.value
        ).build_prompt("choose one", "- A\n- B\n- C", "C")
        res = next(prompt_gen)
        expected_parts = (
            "USER:",
            "choose one",
            "Choose the answer that best answers the question.",
            "- A\n- B\n- C",
        )
        for part in expected_parts:
            assert part in res

31
tests/test_tokenizers.py Normal file
View File

@@ -0,0 +1,31 @@
"""
Test cases for the tokenizer loading
"""
import unittest
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers(unittest.TestCase):
    """
    checks how load_tokenizer reacts to the tokenizer_use_fast option
    """

    @staticmethod
    def _loaded_class_name(options):
        # Load the llama tokenizer with the given config and report its class name.
        tokenizer = load_tokenizer("huggyllama/llama-7b", None, DictDefault(options))
        return tokenizer.__class__.__name__

    def test_default_use_fast(self):
        # no option given: the class name is expected to contain "Fast"
        assert "Fast" in self._loaded_class_name({})

    def test_dont_use_fast(self):
        # option explicitly off: the class name must not contain "Fast"
        assert "Fast" not in self._loaded_class_name(
            {
                "tokenizer_use_fast": False,
            }
        )
if __name__ == "__main__":
unittest.main()

View File

@@ -212,3 +212,104 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=regex_exp): with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg) validate_config(cfg)
def test_flash_optimum(self):
    """flash_optimum validation: two warning cases, then the two AMP rejections"""
    warning_cases = (
        (
            {
                "flash_optimum": True,
                "adapter": "lora",
            },
            "BetterTransformers probably doesn't work with PEFT adapters",
        ),
        (
            {
                "flash_optimum": True,
            },
            "probably set bfloat16 or float16",
        ),
    )
    for options, expected_text in warning_cases:
        cfg = DictDefault(options)
        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert any(
                expected_text in record.message
                for record in self._caplog.records
            )

    # fp16 and bf16 must each be rejected together with flash_optimum
    for amp_option in ("fp16", "bf16"):
        cfg = DictDefault(
            {
                "flash_optimum": True,
                amp_option: True,
            }
        )
        regex_exp = r".*AMP is not supported.*"
        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)
def test_adamw_hyperparams(self):
cfg = DictDefault(
{
"optimizer": None,
"adam_epsilon": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adafactor",
"adam_beta1": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adamw_bnb_8bit",
"adam_beta1": 0.9,
"adam_beta2": 0.99,
"adam_epsilon": 0.0001,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"optimizer": "adafactor",
}
)
validate_config(cfg)