make it work with pythia in the cloud

This commit is contained in:
Wing Lian
2023-04-14 07:24:55 -04:00
parent ce24f5e246
commit 8d959a7e26
7 changed files with 352 additions and 70 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
data/*.jsonl filter=lfs diff=lfs merge=lfs -text

View File

@@ -3,35 +3,36 @@ model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
load_in_8bit: true load_in_8bit: true
datasets: datasets:
- path: ./data/alpaca_data_gpt4.jsonl - path: data/alpaca_data_gpt4.jsonl
type: alpaca type: alpaca
- path: ./data/vicuna_cleaned.jsonl - path: data/vicuna_cleaned.jsonl
type: sharegpt type: sharegpt
- path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher type: gpteacher
- path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher type: gpteacher
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
sequence_len: 2048 sequence_len: 2048
lora_r: 16 lora_r: 8
lora_alpha: 32 lora_alpha: 32
lora_dropout: 0.05 lora_dropout: 0.05
lora_target_modules: lora_target_modules:
- q_proj - query_key_value
- v_proj lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project: pythia-1.4b-lora
wandb_watch: wandb_watch:
wandb:run_name: wandb_run_name:
wandb_log_model: checkpoint wandb_log_model: checkpoint
output_dir: ./lora-alpaca output_dir: ./lora-alpaca
batch_size: 128 batch_size: 32
micro_batch_size: 8 micro_batch_size: 4
num_epochs: 5 num_epochs: 5
learning_rate: 0.0003 learning_rate: 0.0003
train_on_inputs: false train_on_inputs: false
group_by_length: false
bf16: True bf16: True
fp16: True tf32: True
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
deepspeed: deepspeed:

View File

@@ -1,26 +1,32 @@
import math
import os import os
import signal
import sys import sys
from pathlib import Path from pathlib import Path
import bitsandbytes as bnb
import fire import fire
import torch import torch
import transformers import transformers
import yaml import yaml
from attrdict import AttrDict from attrdict import AttrDict
from datasets import load_dataset, IterableDataset from datasets import load_dataset, IterableDataset, Dataset
from peft import ( from peft import (
LoraConfig, LoraConfig,
get_peft_model, get_peft_model,
prepare_model_for_int8_training, prepare_model_for_int8_training, get_peft_model_state_dict,
) )
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from transformers.trainer_pt_utils import get_parameter_names
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src') src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
from axolotl.datasets import TokenizedPromptDataset from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
@@ -29,9 +35,9 @@ def setup_wandb_env_vars(cfg):
if len(cfg.wandb_project) > 0: if len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True cfg.use_wandb = True
if len(cfg.wandb_watch) > 0: if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch os.environ["WANDB_WATCH"] = cfg.wandb_watch
if len(cfg.wandb_log_model) > 0: if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
@@ -61,6 +67,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
if tokenizer.__class__.__name__ == "LlamaTokenizer": if tokenizer.__class__.__name__ == "LlamaTokenizer":
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.load_in_8bit: if cfg.load_in_8bit:
model = prepare_model_for_int8_training(model) model = prepare_model_for_int8_training(model)
@@ -69,6 +79,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
lora_alpha=cfg.lora_alpha, lora_alpha=cfg.lora_alpha,
target_modules=cfg.lora_target_modules, target_modules=cfg.lora_target_modules,
lora_dropout=cfg.lora_dropout, lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
bias="none", bias="none",
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
@@ -79,7 +90,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
model.print_trainable_parameters() model.print_trainable_parameters()
return model, tokenizer return model, tokenizer, lora_config
def train( def train(
@@ -88,7 +99,7 @@ def train(
): ):
# load the config from the yaml file # load the config from the yaml file
with open(config, 'r') as f: with open(config, 'r') as f:
cfg: AttrDict = AttrDict(yaml.load(f)) cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
# if there are any options passed in the cli, if it is something that seems valid from the yaml, # if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value # then overwrite the value
for k, v in enumerate(kwargs): for k, v in enumerate(kwargs):
@@ -107,23 +118,116 @@ def train(
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
# Load the model and tokenizer # Load the model and tokenizer
model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter) model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
datasets = [] datasets = []
for d in cfg.datasets: for d in cfg.datasets:
ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None) ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
if d.type == "alpaca": if d.type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len) ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d.type == "gpteacher": elif d.type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len) ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d.type == "sharegpt": elif d.type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len) ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
test_size=cfg.val_set_size, shuffle=True, seed=42
)
print(constant_len_dataset)
train_dataset = constant_len_dataset["train"]
eval_dataset = constant_len_dataset["test"]
total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
warmup_steps = min(int(0.03 * total_num_steps), 100)
logging_steps = min(int(0.005 * total_num_steps), 10)
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
warmup_steps=warmup_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
bf16=cfg.bf16,
tf32=cfg.tf32,
logging_steps=logging_steps,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps",
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
save_steps=save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=True if cfg.val_set_size > 0 else False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_name if cfg.use_wandb else None,
)
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if n in decay_parameters],
"weight_decay": training_args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
"weight_decay": 0.0,
},
]
adam_bnb_optim = bnb.optim.Adam8bit(
optimizer_grouped_parameters,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
lr=training_args.learning_rate,
)
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
adam_bnb_optim,
training_args.warmup_steps,
total_num_steps,
)
trainer = transformers.Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
optimizers=(adam_bnb_optim, lr_scheduler),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(
self, old_state_dict()
)
).__get__(model, type(model))
if torch.__version__ >= "2" and sys.platform != "win32":
model = torch.compile(model)
signal.signal(signal.SIGINT, lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
exit(0)
))
# go ahead and presave the adapter config
lora_config.save_pretrained(cfg.output_dir)
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
model.save_pretrained(cfg.output_dir)
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(train) fire.Fire(train)

View File

@@ -44,6 +44,7 @@ class JsonToJsonlConverter:
def convert(self, input_file_path, output_file_path): def convert(self, input_file_path, output_file_path):
content = self.file_reader.read(input_file_path) content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content) data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
jsonl_content = self.jsonl_serializer.serialize(data) jsonl_content = self.jsonl_serializer.serialize(data)
self.file_writer.write(jsonl_content) self.file_writer.write(jsonl_content)

View File

@@ -2,7 +2,7 @@ from typing import List
import torch import torch
from datasets import IterableDataset from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
# We want this to be a wrapper for an existing dataset that we have loaded # We want this to be a wrapper for an existing dataset that we have loaded
@@ -23,7 +23,12 @@ class TokenizedPromptDataset(IterableDataset):
def __iter__(self): def __iter__(self):
iterator = iter(self.dataset) iterator = iter(self.dataset)
yield self.prompt_tokenizer.tokenize_prompt(next(iterator)) # Loop through the entire dataset
for example in iterator:
try:
yield self.prompt_tokenizer.tokenize_prompt(example)
except InvalidDataException:
pass
class ConstantLengthDataset(IterableDataset): class ConstantLengthDataset(IterableDataset):
@@ -32,55 +37,68 @@ class ConstantLengthDataset(IterableDataset):
Args: Args:
tokenizer (Tokenizer): The processor used for proccessing the data. tokenizer (Tokenizer): The processor used for proccessing the data.
dataset (dataset.Dataset): Dataset with text files. dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return. seq_length (int): Length of token sequences to return.
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
""" """
def __init__( def __init__(
self, self,
tokenizer, tokenizer,
datasets, datasets,
infinite=False,
seq_length=2048, seq_length=2048,
num_of_sequences=1024,
chars_per_token=3.6,
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id self.concat_token_id = tokenizer.eos_token_id
self.datasets: List[IterableDataset] = datasets self.datasets: List[IterableDataset] = datasets
self.seq_length = seq_length self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
def __iter__(self): def __iter__(self):
iterator = iter(self.datasets) buffer = {"input_ids": [], "attention_mask": [], "labels": []}
more_examples = True buffer_len = 0
while more_examples: for dataset in self.datasets:
buffer, buffer_len = [], 0 iterator = iter(dataset)
while True: more_examples = True
if buffer_len >= self.max_buffer_size: while more_examples:
break
try: try:
buffer.append(next(iterator)) example = next(iterator)
buffer_len += len(buffer[-1])
except StopIteration: except StopIteration:
if self.infinite: more_examples = False
iterator = iter(self.datasets) example = None
else:
more_examples = False add_concat_token = False
break if example:
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"] example_len = len(example["input_ids"])
all_token_ids = [] add_concat_token = example["input_ids"][-1] != self.concat_token_id
for tokenized_input in tokenized_inputs: else:
all_token_ids.extend(tokenized_input + [self.concat_token_id]) example_len = 0
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length] if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
if len(input_ids) == self.seq_length: if buffer["input_ids"]:
self.current_size += 1 input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
yield { attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
"input_ids": torch.LongTensor(input_ids), labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
"labels": torch.LongTensor(input_ids), yield {
"attention_masks": torch.LongTensor(input_ids), "input_ids": input_ids,
} "labels": labels,
"attention_mask": attention_mask,
}
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer_len = 0
if example:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
labels_with_concat = torch.tensor(labels, dtype=torch.long)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer_len += len(input_ids)

View File

@@ -9,6 +9,10 @@ LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
class InvalidDataException(Exception):
pass
class PromptTokenizingStrategy(abc.ABC): class PromptTokenizingStrategy(abc.ABC):
def __init__( def __init__(
self, self,
@@ -32,7 +36,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
full_prompt = self._tokenize_full_prompt(prompt) full_prompt = self._tokenize_full_prompt(prompt)
tokenized_full_prompt = self._tokenize(full_prompt) tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs: if not self.train_on_inputs:
user_prompt = self.prompter.generate_prompt( user_prompt = self.prompter.build_prompt(
prompt["instruction"], prompt["input"] prompt["instruction"], prompt["input"]
) )
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
@@ -43,7 +47,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt return tokenized_full_prompt
def _tokenize_full_prompt(self, prompt): def _tokenize_full_prompt(self, prompt):
return self.prompter.generate_prompt( return self.prompter.build_prompt(
prompt["instruction"], prompt["instruction"],
prompt["input"], prompt["input"],
prompt["output"], prompt["output"],
@@ -71,7 +75,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy): class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
def _tokenize_full_prompt(self, prompt): def _tokenize_full_prompt(self, prompt):
return self.prompter.generate_prompt( return self.prompter.build_prompt(
prompt["instruction"], prompt["instruction"],
prompt["input"], prompt["input"],
prompt["response"], prompt["response"],
@@ -80,4 +84,7 @@ class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
pass try:
return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
except (KeyError, AssertionError) as e:
raise InvalidDataException(str(e))

View File

@@ -1,10 +1,160 @@
import copy
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any, Union
IGNORE_TOKEN_ID = -100
class AlpacaPrompter: class AlpacaPrompter:
pass prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
response_split = "### Response:"
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None,
output: Union[None, str] = None,
) -> str:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.prompt_input.format(
instruction=instruction, input=input
)
else:
res = self.prompt_no_input.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
return res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class GPTeacherPrompter(AlpacaPrompter):
...
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
DOLLY = auto()
# TODO clean this 💩 up
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
def get_prompt(self):
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
)
def append_message(self, role, message):
self.messages.append([role, message])
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
class ShareGPTPrompter: class ShareGPTPrompter:
pass def build_prompt(
self,
source,
tokenizer
):
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError
conv = conv_vicuna_v1_1.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
class GPTeacherPrompter: try:
pass # Apply prompt templates
if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
except IndexError as e:
# sometimes there is a bing or system chat
raise e
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2]
conv.append_message(role, sentence["value"])
conversation = conv.get_prompt()
# Tokenize conversations
tokenized_result = tokenizer(
conversation,
truncation=True,
max_length=2048, # FIXME
padding=False,
return_tensors=None,
)
target = copy.deepcopy(tokenized_result["input_ids"])
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
rounds = conversation.split(conv.sep2)
cur_len = 1
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(tokenizer(rou)["input_ids"])
instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
target[cur_len:cur_len+instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
cur_len += round_len
target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
attention_mask = [1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"]]
return dict(input_ids=tokenized_result["input_ids"], labels=target,
attention_mask=attention_mask)