Compare commits

...

41 Commits

Author SHA1 Message Date
Wing Lian
e91fed495a better handling for tokenizers like flan that don't have a bos token
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-06-23 15:47:40 -04:00
Wing Lian
756dfba97b Merge pull request #218 from OpenAccess-AI-Collective/no-fail-fast
don't fail fast
2023-06-23 15:42:54 -04:00
Wing Lian
91ab0592af Merge pull request #235 from msinha251/Fixing-data-readme 2023-06-23 13:52:01 -04:00
Mahesh Sinha
0aeb7c7802 Fixing Data Readme 2023-06-21 15:34:48 +02:00
Wing Lian
d35278aaf1 don't fail fast 2023-06-15 16:01:27 -04:00
Wing Lian
9492d4ebb7 Merge pull request #215 from OpenAccess-AI-Collective/adamw-hyperparams-cfg
support adamw and grad norm hyperparams
2023-06-15 12:20:55 -04:00
Wing Lian
ad5ca4f734 Additional test case per pr 2023-06-15 10:12:47 -04:00
Wing Lian
cb9d3af5c0 add validation and tests for adamw hyperparam 2023-06-15 09:39:42 -04:00
Wing Lian
c969f0a9dc add docs 2023-06-15 08:43:20 -04:00
Wing Lian
6d0ee4ba34 support adamw and grad norm hyperparams 2023-06-15 08:40:41 -04:00
Wing Lian
a81f52d575 Merge pull request #212 from OpenAccess-AI-Collective/doc-20230615-v1
add float16 docs and tweak typehints
2023-06-15 08:28:57 -04:00
Wing Lian
1925eaf1e6 Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Fix tokenizing labels
2023-06-15 08:13:43 -04:00
Wing Lian
1ab3bf3e67 fix test name 2023-06-15 02:09:33 -04:00
Wing Lian
d7635b7148 hint to what AMP means 2023-06-15 02:06:27 -04:00
Wing Lian
88e17ffc50 add float16 docs and tweak typehints 2023-06-15 02:05:31 -04:00
Wing Lian
baed440fa1 ingore duplicate code in tests 2023-06-15 02:03:53 -04:00
Wing Lian
7925ddce86 bugfix for potential off by one 2023-06-15 01:59:33 -04:00
Wing Lian
6f849809c5 Merge pull request #206 from MaciejKarasek/issue205
issue #205 bugfix
2023-06-14 14:23:38 -04:00
Wing Lian
c16644d05e Merge pull request #209 from sroecker/fix_redpajama_example_tokenizer
Use AutoTokenizer for redpajama example
2023-06-14 14:23:21 -04:00
Steffen Röcker
945c4191a3 Use AutoTokenizer for redpajama example 2023-06-14 20:09:26 +02:00
maciej.karasek
136522f9c9 style correction 2023-06-14 20:02:09 +02:00
maciej.karasek
556fe408b3 issue #205 bugfix 2023-06-14 16:59:57 +02:00
Wing Lian
16bb6276a5 Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum
add support for opimum bettertransformers
2023-06-14 07:50:15 -04:00
NanoCode012
06674a11f2 Merge pull request #202 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix sharegpt type in doc
2023-06-14 09:48:35 +09:00
NanoCode012
3513885f43 Fix sharegpt type 2023-06-14 01:10:58 +09:00
Wing Lian
4b43a66a0b update alpaca_chat prompts for instructions to explainn the conversation 2023-06-12 18:38:38 -04:00
Wing Lian
fd2c9814c9 Merge branch 'main' into flash-optimum 2023-06-12 13:12:15 -04:00
Wing Lian
c9a149f9e8 add check for attr 2023-06-11 10:11:17 -04:00
Wing Lian
958da70376 fix formatting 2023-06-10 15:28:08 -04:00
Wing Lian
759e8673ce Update scripts/finetune.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-10 14:25:21 -04:00
Wing Lian
0c6f928601 address PR feedback 2023-06-10 14:23:56 -04:00
Wing Lian
eea2731a5e add streaming dataset support for pretraining datasets 2023-06-10 14:23:56 -04:00
Wing Lian
1db46a9c72 linting fix 2023-06-10 14:23:56 -04:00
Wing Lian
ab5cd28acf more gpt-neox long ctx fixes 2023-06-10 14:23:55 -04:00
Wing Lian
1a82082e91 fix bettertransformers save, force it to skip after saving correctly in callback 2023-06-10 14:23:55 -04:00
Wing Lian
1210dc8fd5 more tweaks to do pre-training with bettertransformers 2023-06-10 14:23:55 -04:00
Wing Lian
488a67d75a experimental expansion of ctx len 2023-06-10 14:23:53 -04:00
Wing Lian
71a43f8479 add validation/warning for bettertransformers and torch version 2023-06-10 14:22:31 -04:00
Wing Lian
39619028a3 use pythia-12b, neox-20b is flaky 2023-06-10 14:22:30 -04:00
Wing Lian
8792199799 add flash attn context for efficient training and attempt setting model to train mode: 2023-06-10 14:22:30 -04:00
Wing Lian
1edc30c786 add support for opimum bettertransformers 2023-06-10 14:22:30 -04:00
19 changed files with 571 additions and 57 deletions

View File

@@ -12,6 +12,7 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: self-hosted
strategy:
fail-fast: false
matrix:
include:
- cuda: "118"

View File

@@ -11,6 +11,7 @@ jobs:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: cu118

View File

@@ -7,6 +7,7 @@ jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.9", "3.10"]
timeout-minutes: 10

View File

@@ -138,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations
- `sharegpt:chat`: conversations
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -264,6 +264,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
bf16: true # require >=ampere
fp16: true
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
Note: Repo does not do 4-bit quantization.
@@ -420,7 +422,15 @@ log_sweep_max_lr:
optimizer:
# specify weight decay
weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_epsilon:
# Gradient clipping max norm
max_grad_norm:
# whether to bettertransformers
flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
@@ -520,6 +530,12 @@ Add below flag to train command above
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
```
If you run out of CUDA memory, you can try to merge in system RAM with
```bash
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
```
## Common Errors 🧰
> Cuda out of memory

View File

@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
## Convert the JSON data files to JSONL.
```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
---

View File

@@ -0,0 +1,9 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️

View File

@@ -0,0 +1,49 @@
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:
collator_pad_to_longest: true

View File

@@ -1,7 +1,7 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM
tokenizer_type: GPTNeoXTokenizer
tokenizer_type: AutoTokenizer
trust_remote_code:
load_in_8bit: false
datasets:

View File

@@ -11,6 +11,7 @@ sentencepiece
wandb
einops
xformers
optimum
# qlora things
bert-score==0.3.13
evaluate==0.4.0

View File

@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
@@ -217,9 +218,20 @@ def train(
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
@@ -285,12 +297,15 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
sys.exit(0)
signal.signal(
signal.SIGINT,
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -20,11 +20,36 @@ def load(tokenizer, cfg):
class AlpacaConcisePrompter(AlpacaPrompter):
"""
Alpaca Prompter extending the system prompt to ask for concise answers
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
"""
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
class AlpacaChatPrompter(AlpacaPrompter):
"""
Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
"""
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
def __init__(self): # pylint: disable=super-init-not-called
self.prompt_style = PromptStyle.CHAT.value
self.match_prompt_style()
class NoSystemPrompter(AlpacaPrompter):
"""
Null Prompter with no system prompts
"""
prompt_input = "{instruction} {input} "
prompt_no_input = "{instruction} "
def __init__(self): # pylint: disable=super-init-not-called
pass
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +89,7 @@ def load_concise(tokenizer, cfg):
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value),
AlpacaChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -73,7 +98,7 @@ def load_qa(tokenizer, cfg):
def load_camel_ai(tokenizer, cfg):
return CamelAIPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value),
AlpacaChatPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -73,8 +73,17 @@ class PromptTokenizingStrategy(abc.ABC):
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
elif ( # some tokenizers automatically add an eos token, let's remove it
not add_eos_token and result["input_ids"][-1] == self.tokenizer.eos_token_id
):
result["input_ids"] = result["input_ids"][:-1]
result["attention_mask"] = result["attention_mask"][:-1]
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
if (
self.tokenizer.bos_token_id
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
@@ -96,25 +105,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
input, # pylint: disable=redefined-builtin
response,
) = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, response)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = next(
iter(
self.prompter.build_prompt(
instruction,
input,
)
user_prompt = next(
iter(
self.prompter.build_prompt(
instruction,
input,
)
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
)
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [
-100
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_full_prompt
return tokenized_prompt
def _build_full_prompt(
self, instruction, input, response # pylint: disable=redefined-builtin
@@ -410,7 +421,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
if (
self.tokenizer.bos_token_id
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]

View File

@@ -2,13 +2,14 @@
import os
from optimum.bettertransformer import BetterTransformer
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
kwargs["model"].save_pretrained(peft_model_path)
return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods
"""Callback to save the BetterTransformer wrapped model"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control

View File

@@ -1,10 +1,11 @@
"""Module containing data utilities"""
import functools
import logging
from hashlib import md5
from pathlib import Path
from typing import List, Tuple, Union
import torch
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
@@ -394,8 +395,127 @@ def load_prepare_datasets(
index=cfg.dataset_shard_idx,
)
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
if cfg.val_set_size:
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples):
res = tokenizer(
examples["text"],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
input_ids[i] = torch.cat(
(
input_ids[i],
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
logging.debug(len(ret["input_ids"]))
return ret
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
return dataset

View File

@@ -10,13 +10,15 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from transformers import PreTrainedModel # noqa: F401
from optimum.bettertransformer import BetterTransformer
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -70,7 +72,7 @@ def load_tokenizer(
def load_model(
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""
@@ -121,9 +123,9 @@ def load_model(
logging.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope()
if cfg.bf16:
if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16:
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
@@ -251,11 +253,16 @@ def load_model(
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
if (
hasattr(config, "max_seq_len")
and config.max_seq_len
and cfg.sequence_len > config.max_seq_len
):
config.max_seq_len = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
@@ -278,6 +285,7 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False,
@@ -287,6 +295,16 @@ def load_model(
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len)
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
):
logging.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +350,9 @@ def load_model(
logging.warning("there are no parameters that require gradient updates")
model.config.use_cache = False
if cfg.flash_optimum:
model = BetterTransformer.transform(model)
# TODO resume_from_checkpoint handling
return model, lora_config

View File

@@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback
from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.schedulers import InterpolatingLogScheduler
@@ -112,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
# TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
@@ -228,6 +240,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True,
}

View File

@@ -2,6 +2,8 @@
import logging
import torch
def validate_config(cfg):
if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,42 @@ def validate_config(cfg):
) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bloat16 is not True:
logging.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".")[0]) < 2:
logging.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
if cfg.pretraining_dataset and cfg.group_by_length:
logging.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
# no 8bit adamw w bf16
# no 8bit adaAmw w bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -6,8 +6,12 @@ from pathlib import Path
from transformers import AutoTokenizer
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompter
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
logging.basicConfig(level="INFO")
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
)
def test_sharegpt_integration(self):
print(Path(__file__).parent)
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
@@ -53,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts
"""
prompter = NoSystemPrompter()
# pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
sample = {
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
"output": "world!",
}
example = strat.tokenize_prompt(sample)
world_idx = example["input_ids"].index(3186)
assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100
def test_alpaca(self):
"""
tests the interface between the user and assistant parts
"""
# pylint: disable=duplicate-code
prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
example = strat.tokenize_prompt(sample)
world_idx = example["input_ids"].index(6324)
assert example["labels"][world_idx] == 6324
assert example["labels"][world_idx - 1] == -100
if __name__ == "__main__":
unittest.main()

View File

@@ -212,3 +212,104 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
def test_flash_optimum(self):
cfg = DictDefault(
{
"flash_optimum": True,
"adapter": "lora",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"BetterTransformers probably doesn't work with PEFT adapters"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"flash_optimum": True,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"probably set bfloat16 or float16" in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"flash_optimum": True,
"fp16": True,
}
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
cfg = DictDefault(
{
"flash_optimum": True,
"bf16": True,
}
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
def test_adamw_hyperparams(self):
cfg = DictDefault(
{
"optimizer": None,
"adamw_epsilon": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adafactor",
"adamw_beta1": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adamw_bnb_8bit",
"adamw_beta1": 0.0001,
"adamw_beta2": 0.0001,
"adamw_epsilon": 0.0001,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"optimizer": "adafactor",
}
)
validate_config(cfg)