make it work with pythia in the cloud
This commit is contained in:
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
data/*.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
@@ -3,35 +3,36 @@ model_type: GPTNeoXForCausalLM
|
|||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
datasets:
|
datasets:
|
||||||
- path: ./data/alpaca_data_gpt4.jsonl
|
- path: data/alpaca_data_gpt4.jsonl
|
||||||
type: alpaca
|
type: alpaca
|
||||||
- path: ./data/vicuna_cleaned.jsonl
|
- path: data/vicuna_cleaned.jsonl
|
||||||
type: sharegpt
|
type: sharegpt
|
||||||
- path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
||||||
type: gpteacher
|
type: gpteacher
|
||||||
- path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||||
type: gpteacher
|
type: gpteacher
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
lora_r: 16
|
lora_r: 8
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
- q_proj
|
- query_key_value
|
||||||
- v_proj
|
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||||
wandb_project:
|
wandb_project: pythia-1.4b-lora
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb:run_name:
|
wandb_run_name:
|
||||||
wandb_log_model: checkpoint
|
wandb_log_model: checkpoint
|
||||||
output_dir: ./lora-alpaca
|
output_dir: ./lora-alpaca
|
||||||
batch_size: 128
|
batch_size: 32
|
||||||
micro_batch_size: 8
|
micro_batch_size: 4
|
||||||
num_epochs: 5
|
num_epochs: 5
|
||||||
learning_rate: 0.0003
|
learning_rate: 0.0003
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
bf16: True
|
bf16: True
|
||||||
fp16: True
|
tf32: True
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -1,26 +1,32 @@
|
|||||||
|
import math
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import yaml
|
import yaml
|
||||||
from attrdict import AttrDict
|
from attrdict import AttrDict
|
||||||
from datasets import load_dataset, IterableDataset
|
from datasets import load_dataset, IterableDataset, Dataset
|
||||||
from peft import (
|
from peft import (
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
get_peft_model,
|
get_peft_model,
|
||||||
prepare_model_for_int8_training,
|
prepare_model_for_int8_training, get_peft_model_state_dict,
|
||||||
)
|
)
|
||||||
|
from torch import nn
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
src_dir = os.path.join(project_root, 'src')
|
src_dir = os.path.join(project_root, 'src')
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
|
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
|
||||||
LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
|
LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
|
||||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
||||||
@@ -29,9 +35,9 @@ def setup_wandb_env_vars(cfg):
|
|||||||
if len(cfg.wandb_project) > 0:
|
if len(cfg.wandb_project) > 0:
|
||||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||||
cfg.use_wandb = True
|
cfg.use_wandb = True
|
||||||
if len(cfg.wandb_watch) > 0:
|
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||||
if len(cfg.wandb_log_model) > 0:
|
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +67,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|||||||
if tokenizer.__class__.__name__ == "LlamaTokenizer":
|
if tokenizer.__class__.__name__ == "LlamaTokenizer":
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
|
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
if cfg.load_in_8bit:
|
if cfg.load_in_8bit:
|
||||||
model = prepare_model_for_int8_training(model)
|
model = prepare_model_for_int8_training(model)
|
||||||
|
|
||||||
@@ -69,6 +79,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|||||||
lora_alpha=cfg.lora_alpha,
|
lora_alpha=cfg.lora_alpha,
|
||||||
target_modules=cfg.lora_target_modules,
|
target_modules=cfg.lora_target_modules,
|
||||||
lora_dropout=cfg.lora_dropout,
|
lora_dropout=cfg.lora_dropout,
|
||||||
|
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
@@ -79,7 +90,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
return model, tokenizer
|
return model, tokenizer, lora_config
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@@ -88,7 +99,7 @@ def train(
|
|||||||
):
|
):
|
||||||
# load the config from the yaml file
|
# load the config from the yaml file
|
||||||
with open(config, 'r') as f:
|
with open(config, 'r') as f:
|
||||||
cfg: AttrDict = AttrDict(yaml.load(f))
|
cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||||
# then overwrite the value
|
# then overwrite the value
|
||||||
for k, v in enumerate(kwargs):
|
for k, v in enumerate(kwargs):
|
||||||
@@ -107,23 +118,116 @@ def train(
|
|||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
|
model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
|
||||||
datasets = []
|
datasets = []
|
||||||
for d in cfg.datasets:
|
for d in cfg.datasets:
|
||||||
ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
|
ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
|
||||||
if d.type == "alpaca":
|
if d.type == "alpaca":
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d.type == "gpteacher":
|
elif d.type == "gpteacher":
|
||||||
ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d.type == "sharegpt":
|
elif d.type == "sharegpt":
|
||||||
ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
|
constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
|
||||||
|
constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
|
||||||
|
test_size=cfg.val_set_size, shuffle=True, seed=42
|
||||||
|
)
|
||||||
|
|
||||||
|
print(constant_len_dataset)
|
||||||
|
train_dataset = constant_len_dataset["train"]
|
||||||
|
eval_dataset = constant_len_dataset["test"]
|
||||||
|
|
||||||
|
total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
|
||||||
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
|
logging_steps = min(int(0.005 * total_num_steps), 10)
|
||||||
|
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
|
||||||
|
|
||||||
|
training_args = transformers.TrainingArguments(
|
||||||
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
num_train_epochs=cfg.num_epochs,
|
||||||
|
learning_rate=cfg.learning_rate,
|
||||||
|
bf16=cfg.bf16,
|
||||||
|
tf32=cfg.tf32,
|
||||||
|
logging_steps=logging_steps,
|
||||||
|
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
||||||
|
save_strategy="steps",
|
||||||
|
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
|
||||||
|
save_steps=save_steps,
|
||||||
|
output_dir=cfg.output_dir,
|
||||||
|
save_total_limit=3,
|
||||||
|
load_best_model_at_end=True if cfg.val_set_size > 0 else False,
|
||||||
|
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||||
|
group_by_length=cfg.group_by_length,
|
||||||
|
report_to="wandb" if cfg.use_wandb else None,
|
||||||
|
run_name=cfg.wandb_run_name if cfg.use_wandb else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
||||||
|
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
|
||||||
|
"weight_decay": training_args.weight_decay,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
adam_bnb_optim = bnb.optim.Adam8bit(
|
||||||
|
optimizer_grouped_parameters,
|
||||||
|
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
||||||
|
eps=training_args.adam_epsilon,
|
||||||
|
lr=training_args.learning_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
|
||||||
|
adam_bnb_optim,
|
||||||
|
training_args.warmup_steps,
|
||||||
|
total_num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = transformers.Trainer(
|
||||||
|
model=model,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
args=training_args,
|
||||||
|
optimizers=(adam_bnb_optim, lr_scheduler),
|
||||||
|
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
|
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
|
old_state_dict = model.state_dict
|
||||||
|
model.state_dict = (
|
||||||
|
lambda self, *_, **__: get_peft_model_state_dict(
|
||||||
|
self, old_state_dict()
|
||||||
|
)
|
||||||
|
).__get__(model, type(model))
|
||||||
|
|
||||||
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
|
model = torch.compile(model)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, lambda signal, frame: (
|
||||||
|
model.save_pretrained(cfg.output_dir),
|
||||||
|
exit(0)
|
||||||
|
))
|
||||||
|
|
||||||
|
# go ahead and presave the adapter config
|
||||||
|
lora_config.save_pretrained(cfg.output_dir)
|
||||||
|
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
||||||
|
|
||||||
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fire.Fire(train)
|
fire.Fire(train)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class JsonToJsonlConverter:
|
|||||||
def convert(self, input_file_path, output_file_path):
|
def convert(self, input_file_path, output_file_path):
|
||||||
content = self.file_reader.read(input_file_path)
|
content = self.file_reader.read(input_file_path)
|
||||||
data = self.json_parser.parse(content)
|
data = self.json_parser.parse(content)
|
||||||
|
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
||||||
jsonl_content = self.jsonl_serializer.serialize(data)
|
jsonl_content = self.jsonl_serializer.serialize(data)
|
||||||
self.file_writer.write(jsonl_content)
|
self.file_writer.write(jsonl_content)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import List
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import IterableDataset
|
from datasets import IterableDataset
|
||||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
|
||||||
|
|
||||||
|
|
||||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||||
@@ -23,7 +23,12 @@ class TokenizedPromptDataset(IterableDataset):
|
|||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
iterator = iter(self.dataset)
|
iterator = iter(self.dataset)
|
||||||
yield self.prompt_tokenizer.tokenize_prompt(next(iterator))
|
# Loop through the entire dataset
|
||||||
|
for example in iterator:
|
||||||
|
try:
|
||||||
|
yield self.prompt_tokenizer.tokenize_prompt(example)
|
||||||
|
except InvalidDataException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ConstantLengthDataset(IterableDataset):
|
class ConstantLengthDataset(IterableDataset):
|
||||||
@@ -32,55 +37,68 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
Args:
|
Args:
|
||||||
tokenizer (Tokenizer): The processor used for processing the data.
|
tokenizer (Tokenizer): The processor used for processing the data.
|
||||||
dataset (dataset.Dataset): Dataset with text files.
|
dataset (dataset.Dataset): Dataset with text files.
|
||||||
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
|
|
||||||
seq_length (int): Length of token sequences to return.
|
seq_length (int): Length of token sequences to return.
|
||||||
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
datasets,
|
datasets,
|
||||||
infinite=False,
|
|
||||||
seq_length=2048,
|
seq_length=2048,
|
||||||
num_of_sequences=1024,
|
|
||||||
chars_per_token=3.6,
|
|
||||||
):
|
):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
|
self.concat_token_id = tokenizer.eos_token_id
|
||||||
self.datasets: List[IterableDataset] = datasets
|
self.datasets: List[IterableDataset] = datasets
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.infinite = infinite
|
|
||||||
self.current_size = 0
|
|
||||||
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
iterator = iter(self.datasets)
|
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
more_examples = True
|
buffer_len = 0
|
||||||
while more_examples:
|
for dataset in self.datasets:
|
||||||
buffer, buffer_len = [], 0
|
iterator = iter(dataset)
|
||||||
while True:
|
more_examples = True
|
||||||
if buffer_len >= self.max_buffer_size:
|
while more_examples:
|
||||||
break
|
|
||||||
try:
|
try:
|
||||||
buffer.append(next(iterator))
|
example = next(iterator)
|
||||||
buffer_len += len(buffer[-1])
|
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
if self.infinite:
|
more_examples = False
|
||||||
iterator = iter(self.datasets)
|
example = None
|
||||||
else:
|
|
||||||
more_examples = False
|
add_concat_token = False
|
||||||
break
|
if example:
|
||||||
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
|
example_len = len(example["input_ids"])
|
||||||
all_token_ids = []
|
add_concat_token = example["input_ids"][-1] != self.concat_token_id
|
||||||
for tokenized_input in tokenized_inputs:
|
else:
|
||||||
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
example_len = 0
|
||||||
for i in range(0, len(all_token_ids), self.seq_length):
|
|
||||||
input_ids = all_token_ids[i : i + self.seq_length]
|
if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
|
||||||
if len(input_ids) == self.seq_length:
|
if buffer["input_ids"]:
|
||||||
self.current_size += 1
|
input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
|
||||||
yield {
|
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
|
||||||
"input_ids": torch.LongTensor(input_ids),
|
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||||
"labels": torch.LongTensor(input_ids),
|
yield {
|
||||||
"attention_masks": torch.LongTensor(input_ids),
|
"input_ids": input_ids,
|
||||||
}
|
"labels": labels,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
|
buffer_len = 0
|
||||||
|
|
||||||
|
if example:
|
||||||
|
input_ids = example["input_ids"]
|
||||||
|
attention_mask = example["attention_mask"]
|
||||||
|
labels = example["labels"]
|
||||||
|
|
||||||
|
if add_concat_token:
|
||||||
|
input_ids.append(self.concat_token_id)
|
||||||
|
attention_mask.append(1)
|
||||||
|
labels.append(self.concat_token_id)
|
||||||
|
|
||||||
|
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
||||||
|
attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
|
||||||
|
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
||||||
|
|
||||||
|
buffer["input_ids"].append(input_ids_with_concat)
|
||||||
|
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||||
|
buffer["labels"].append(labels_with_concat)
|
||||||
|
buffer_len += len(input_ids)
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ LLAMA_DEFAULT_BOS_TOKEN = "<s>"
|
|||||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
|
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidDataException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PromptTokenizingStrategy(abc.ABC):
|
class PromptTokenizingStrategy(abc.ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -32,7 +36,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
full_prompt = self._tokenize_full_prompt(prompt)
|
full_prompt = self._tokenize_full_prompt(prompt)
|
||||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
user_prompt = self.prompter.generate_prompt(
|
user_prompt = self.prompter.build_prompt(
|
||||||
prompt["instruction"], prompt["input"]
|
prompt["instruction"], prompt["input"]
|
||||||
)
|
)
|
||||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||||
@@ -43,7 +47,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
|
|
||||||
def _tokenize_full_prompt(self, prompt):
|
def _tokenize_full_prompt(self, prompt):
|
||||||
return self.prompter.generate_prompt(
|
return self.prompter.build_prompt(
|
||||||
prompt["instruction"],
|
prompt["instruction"],
|
||||||
prompt["input"],
|
prompt["input"],
|
||||||
prompt["output"],
|
prompt["output"],
|
||||||
@@ -71,7 +75,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
|
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
|
||||||
def _tokenize_full_prompt(self, prompt):
|
def _tokenize_full_prompt(self, prompt):
|
||||||
return self.prompter.generate_prompt(
|
return self.prompter.build_prompt(
|
||||||
prompt["instruction"],
|
prompt["instruction"],
|
||||||
prompt["input"],
|
prompt["input"],
|
||||||
prompt["response"],
|
prompt["response"],
|
||||||
@@ -80,4 +84,7 @@ class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
|
|||||||
|
|
||||||
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
pass
|
try:
|
||||||
|
return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
|
||||||
|
except (KeyError, AssertionError) as e:
|
||||||
|
raise InvalidDataException(str(e))
|
||||||
|
|||||||
@@ -1,10 +1,160 @@
|
|||||||
|
import copy
|
||||||
|
import dataclasses
|
||||||
|
from enum import auto, Enum
|
||||||
|
from typing import List, Tuple, Any, Union
|
||||||
|
|
||||||
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
|
||||||
class AlpacaPrompter:
|
class AlpacaPrompter:
|
||||||
pass
|
prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
|
||||||
|
response_split = "### Response:"
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
input: Union[None, str] = None,
|
||||||
|
output: Union[None, str] = None,
|
||||||
|
) -> str:
|
||||||
|
# returns the full prompt from instruction and optional input
|
||||||
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
|
if input:
|
||||||
|
res = self.prompt_input.format(
|
||||||
|
instruction=instruction, input=input
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
res = self.prompt_no_input.format(
|
||||||
|
instruction=instruction
|
||||||
|
)
|
||||||
|
if output:
|
||||||
|
res = f"{res}{output}"
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get_response(self, output: str) -> str:
|
||||||
|
return output.split(self.response_split)[1].strip()
|
||||||
|
|
||||||
|
|
||||||
|
class GPTeacherPrompter(AlpacaPrompter):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class SeparatorStyle(Enum):
|
||||||
|
"""Different separator style."""
|
||||||
|
SINGLE = auto()
|
||||||
|
TWO = auto()
|
||||||
|
DOLLY = auto()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO clean this 💩 up
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Conversation:
|
||||||
|
"""A class that keeps all conversation history."""
|
||||||
|
system: str
|
||||||
|
roles: List[str]
|
||||||
|
messages: List[List[str]]
|
||||||
|
offset: int
|
||||||
|
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||||
|
sep: str = "###"
|
||||||
|
sep2: str = None
|
||||||
|
|
||||||
|
def get_prompt(self):
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
ret = self.system + seps[0]
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
ret += role + ": " + message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
ret += role + ":"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return Conversation(
|
||||||
|
system=self.system,
|
||||||
|
roles=self.roles,
|
||||||
|
messages=[[x, y] for x, y in self.messages],
|
||||||
|
offset=self.offset,
|
||||||
|
sep_style=self.sep_style,
|
||||||
|
sep=self.sep,
|
||||||
|
sep2=self.sep2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def append_message(self, role, message):
|
||||||
|
self.messages.append([role, message])
|
||||||
|
|
||||||
|
|
||||||
|
conv_vicuna_v1_1 = Conversation(
|
||||||
|
system="A chat between a curious user and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||||
|
roles=["USER", "ASSISTANT"],
|
||||||
|
messages=[],
|
||||||
|
offset=0,
|
||||||
|
sep_style=SeparatorStyle.TWO,
|
||||||
|
sep=" ",
|
||||||
|
sep2="</s>",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter:
|
class ShareGPTPrompter:
|
||||||
pass
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
source,
|
||||||
|
tokenizer
|
||||||
|
):
|
||||||
|
if len(source) < 2:
|
||||||
|
# If there isn't a back and forth conversation, ignore it
|
||||||
|
# also happens on the data splitting leaving empty conversations
|
||||||
|
raise IndexError
|
||||||
|
|
||||||
|
conv = conv_vicuna_v1_1.copy()
|
||||||
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||||
|
|
||||||
class GPTeacherPrompter:
|
try:
|
||||||
pass
|
# Apply prompt templates
|
||||||
|
if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]:
|
||||||
|
# Skip the first one if it is not from human
|
||||||
|
source = source[1:]
|
||||||
|
except IndexError as e:
|
||||||
|
# sometimes there is a bing or system chat
|
||||||
|
raise e
|
||||||
|
|
||||||
|
conv.messages = []
|
||||||
|
for j, sentence in enumerate(source):
|
||||||
|
role = roles[sentence["from"]]
|
||||||
|
assert role == conv.roles[j % 2]
|
||||||
|
conv.append_message(role, sentence["value"])
|
||||||
|
conversation = conv.get_prompt()
|
||||||
|
|
||||||
|
# Tokenize conversations
|
||||||
|
tokenized_result = tokenizer(
|
||||||
|
conversation,
|
||||||
|
truncation=True,
|
||||||
|
max_length=2048, # FIXME
|
||||||
|
padding=False,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
target = copy.deepcopy(tokenized_result["input_ids"])
|
||||||
|
|
||||||
|
# Mask targets
|
||||||
|
sep = conv.sep + conv.roles[1] + ": "
|
||||||
|
|
||||||
|
rounds = conversation.split(conv.sep2)
|
||||||
|
cur_len = 1
|
||||||
|
for i, rou in enumerate(rounds):
|
||||||
|
if rou == "":
|
||||||
|
break
|
||||||
|
|
||||||
|
parts = rou.split(sep)
|
||||||
|
if len(parts) != 2:
|
||||||
|
break
|
||||||
|
parts[0] += sep
|
||||||
|
round_len = len(tokenizer(rou)["input_ids"])
|
||||||
|
instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
|
||||||
|
target[cur_len:cur_len+instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
|
||||||
|
|
||||||
|
cur_len += round_len
|
||||||
|
target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
|
||||||
|
attention_mask = [1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"]]
|
||||||
|
|
||||||
|
return dict(input_ids=tokenized_result["input_ids"], labels=target,
|
||||||
|
attention_mask=attention_mask)
|
||||||
|
|||||||
Reference in New Issue
Block a user