apply black formatting

This commit is contained in:
Wing Lian
2023-05-24 22:59:33 -04:00
parent ce694e20a3
commit ce34d64e8a
10 changed files with 248 additions and 108 deletions

View File

@@ -1,5 +1,6 @@
import importlib import importlib
def load(strategy, tokenizer, cfg): def load(strategy, tokenizer, cfg):
try: try:
load_fn = "load" load_fn = "load"

View File

@@ -1,10 +1,16 @@
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg): def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len AlpacaPrompter(PromptStyle.chat.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
@@ -19,5 +25,8 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def load_qa(tokenizer, cfg): def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy( return AlpacaQAPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len AlpacaPrompter(PromptStyle.chat.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )

View File

@@ -4,5 +4,8 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg): def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.instruct), tokenizer, cfg.train_on_inputs, cfg.sequence_len AlpacaPrompter(PromptStyle.instruct),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )

View File

@@ -7,7 +7,9 @@ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str): def parse_instruction_fields(self, prompt) -> (str, str, str):
question = prompt["instruction"] question = prompt["instruction"]
answer = prompt["revision"] # don't use prompt[answer], that's data we don't want in the dataset answer = prompt[
"revision"
] # don't use prompt[answer], that's data we don't want in the dataset
return ( return (
question, question,
"", "",
@@ -48,8 +50,12 @@ Answer: {answer}
""" """
def parse_instruction_fields(self, prompt) -> (str, str, str): def parse_instruction_fields(self, prompt) -> (str, str, str):
scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper) scores = yaml.dump(
critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper) prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
)
critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
)
evaluation = scores + critiques evaluation = scores + critiques
question = prompt["instruction"] question = prompt["instruction"]
answer = prompt["answer"] answer = prompt["answer"]
@@ -76,13 +82,19 @@ Evaluation:
""" """
def parse_instruction_fields(self, prompt) -> (str, str, str): def parse_instruction_fields(self, prompt) -> (str, str, str):
scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper) scores = yaml.dump(
critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper) prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
)
critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
)
evaluation = scores + critiques evaluation = scores + critiques
question = prompt["instruction"] question = prompt["instruction"]
answer = prompt["answer"] answer = prompt["answer"]
return ( return (
self.user_prompt.format(question=question, answer=answer, evaluation=evaluation), self.user_prompt.format(
question=question, answer=answer, evaluation=evaluation
),
"", "",
prompt["revision"], prompt["revision"],
) )

View File

@@ -30,20 +30,34 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
# this should include a bos token, no eos token, strip trailing "\n<START>" # this should include a bos token, no eos token, strip trailing "\n<START>"
if message.endswith("\n<START>"): if message.endswith("\n<START>"):
message = message[:-8] message = message[:-8]
res = self._tokenize(prefix + "Persona: " + message.strip(), add_eos_token=False, strip_bos_token=False) res = self._tokenize(
prefix + "Persona: " + message.strip(),
add_eos_token=False,
strip_bos_token=False,
)
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif role == "human": elif role == "human":
prefix = "<|user|>" prefix = "<|user|>"
res = self._tokenize(prefix + " " + message.strip(), add_eos_token=False, strip_bos_token=True) res = self._tokenize(
prefix + " " + message.strip(),
add_eos_token=False,
strip_bos_token=True,
)
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif role == "bot": elif role == "bot":
prefix = "<|model|>" prefix = "<|model|>"
res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True) res = self._tokenize(
prefix + " " + message.strip(),
add_eos_token=True,
strip_bos_token=True,
)
# mask out the prefix token, rest is not masked out from labels # mask out the prefix token, rest is not masked out from labels
# make sure we create the labels first, otherwise we get incorrect lengths # make sure we create the labels first, otherwise we get incorrect lengths
labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):] labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [
*copy.deepcopy(res["input_ids"])
][len(self.bot_prefix_token_ids) :]
else: else:
logging.warning(f"unknown role in conversation: {role}") logging.warning(f"unknown role in conversation: {role}")
res = defaultdict(lambda: []) res = defaultdict(lambda: [])
@@ -51,8 +65,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
input_len = len(input_ids) input_len = len(input_ids)
result["input_ids"][current_len : current_len + input_len] = input_ids result["input_ids"][current_len : current_len + input_len] = input_ids
result["attention_mask"][current_len : current_len + input_len] = [ result["attention_mask"][current_len : current_len + input_len] = [
1 if x != self.tokenizer.pad_token_id else 0 1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
for x in input_ids
] ]
result["labels"][current_len : current_len + input_len] = labels result["labels"][current_len : current_len + input_len] = labels
current_len += input_len current_len += input_len
@@ -74,10 +87,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
result["input_ids"].append(self.tokenizer.eos_token_id) result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1) result["attention_mask"].append(1)
if ( if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:] result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:] result["attention_mask"] = result["attention_mask"][1:]

View File

@@ -59,10 +59,14 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
full_prompt = self._build_full_prompt(instruction, input, response) full_prompt = self._build_full_prompt(instruction, input, response)
tokenized_full_prompt = self._tokenize(full_prompt) tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs: if not self.train_on_inputs:
user_prompt = next(iter(self.prompter.build_prompt( user_prompt = next(
instruction, iter(
input, self.prompter.build_prompt(
))) instruction,
input,
)
)
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"]) user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
@@ -73,11 +77,15 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, response): def _build_full_prompt(self, instruction, input, response):
return next(iter(self.prompter.build_prompt( return next(
instruction, iter(
input, self.prompter.build_prompt(
response, instruction,
))) input,
response,
)
)
)
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer( result = self.tokenizer(
@@ -95,10 +103,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
result["input_ids"].append(self.tokenizer.eos_token_id) result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1) result["attention_mask"].append(1)
if ( if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:] result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:] result["attention_mask"] = result["attention_mask"][1:]
@@ -201,10 +206,14 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
) )
tokenized_full_prompt = self._tokenize(full_prompt) tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs: if not self.train_on_inputs:
user_prompt = next(iter(self.prompter.build_prompt( user_prompt = next(
instruction, iter(
input, self.prompter.build_prompt(
))) instruction,
input,
)
)
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"]) user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
@@ -215,13 +224,17 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, output, reflection, corrected): def _build_full_prompt(self, instruction, input, output, reflection, corrected):
return next(iter(self.prompter.build_prompt( return next(
instruction, iter(
input, self.prompter.build_prompt(
output, instruction,
reflection, input,
corrected, output,
))) reflection,
corrected,
)
)
)
def _tokenize(self, prompt, add_eos_token=True): def _tokenize(self, prompt, add_eos_token=True):
result = self.tokenizer( result = self.tokenizer(
@@ -265,21 +278,27 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
user_token = self._get_user_token() user_token = self._get_user_token()
assistant_token = self._get_assistant_token() assistant_token = self._get_assistant_token()
try: try:
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): for i, part in enumerate(
self.prompter.build_prompt(prompt["conversations"])
):
if isinstance(part, tuple): if isinstance(part, tuple):
if part[0] == "USER:": if part[0] == "USER:":
part = part[0] + part[1] if not user_token else part[1] part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should # this is still the user query, we should
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True) res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=True
)
if user_token: if user_token:
res["input_ids"] = [user_token, *res["input_ids"]] res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif part[0] == "ASSISTANT:": elif part[0] == "ASSISTANT:":
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
part = part[0] + part[1] if not assistant_token else part[1] part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistent response, should end with an eos token # this should be the assistent response, should end with an eos token
res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True) res = self._tokenize(
part.strip(), add_eos_token=True, strip_bos_token=True
)
if assistant_token: if assistant_token:
res["input_ids"] = [assistant_token, *res["input_ids"]] res["input_ids"] = [assistant_token, *res["input_ids"]]
# not masked out from labels # not masked out from labels
@@ -288,15 +307,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
logging.warning("unhandled role: " + part[0]) logging.warning("unhandled role: " + part[0])
else: else:
# this is only ever the first part, should include the bos token and the user query # this is only ever the first part, should include the bos token and the user query
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False) res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
input_ids = res["input_ids"] input_ids = res["input_ids"]
input_len = len(input_ids) input_len = len(input_ids)
result["input_ids"][current_len : current_len + input_len] = input_ids result["input_ids"][current_len : current_len + input_len] = input_ids
result["attention_mask"][current_len : current_len + input_len] = [ result["attention_mask"][current_len : current_len + input_len] = [
1 if x != self.tokenizer.pad_token_id else 0 1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
for x in input_ids
] ]
result["labels"][current_len : current_len + input_len] = labels result["labels"][current_len : current_len + input_len] = labels
current_len += input_len current_len += input_len
@@ -320,10 +340,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
result["input_ids"].append(self.tokenizer.eos_token_id) result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1) result["attention_mask"].append(1)
if ( if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:] result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:] result["attention_mask"] = result["attention_mask"][1:]

View File

@@ -23,12 +23,22 @@ class AlpacaPrompter:
def match_prompt_style(self): def match_prompt_style(self):
if self.prompt_style == PromptStyle.instruct.value: if self.prompt_style == PromptStyle.instruct.value:
self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" self.prompt_input = (
self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n" self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:" self.response_split = "### Response:"
if self.prompt_style == PromptStyle.chat.value: if self.prompt_style == PromptStyle.chat.value:
self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" self.prompt_input = (
self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:" self.response_split = "ASSISTANT:"
def build_prompt( def build_prompt(
@@ -55,12 +65,15 @@ class UnpromptedPrompter(AlpacaPrompter):
system_prompt = "" system_prompt = ""
system_no_input_prompt = "" system_no_input_prompt = ""
class JeopardyPrompter(AlpacaPrompter): class JeopardyPrompter(AlpacaPrompter):
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
class MultipleChoiceExplainPrompter(AlpacaPrompter): class MultipleChoiceExplainPrompter(AlpacaPrompter):
system_prompt = "Choose the answer that best answers the question. Explain your reasoning." system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning."
)
class MultipleChoiceConcisePrompter(AlpacaPrompter): class MultipleChoiceConcisePrompter(AlpacaPrompter):
@@ -68,11 +81,15 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
class SummarizeTLDRPrompter(AlpacaPrompter): class SummarizeTLDRPrompter(AlpacaPrompter):
prompt_no_input = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" prompt_no_input = (
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
)
class CompletionPrompter(AlpacaPrompter): class CompletionPrompter(AlpacaPrompter):
def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]: def build_prompt(
self, instruction: str, input=None, output=None
) -> Generator[str, None, None]:
yield instruction yield instruction
def get_response(self, output: str) -> str: def get_response(self, output: str) -> str:
@@ -91,7 +108,9 @@ class ReflectAlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
prompt_input = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" prompt_input = (
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n" prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
response_split = "### Response:" response_split = "### Response:"
@@ -102,14 +121,26 @@ class ReflectAlpacaPrompter:
def match_prompt_style(self): def match_prompt_style(self):
if self.prompt_style == PromptStyle.instruct.value: if self.prompt_style == PromptStyle.instruct.value:
self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" self.prompt_input = (
self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n" self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
self.response_split = "### Final Response:" self.response_split = "### Final Response:"
if self.prompt_style == PromptStyle.chat.value: if self.prompt_style == PromptStyle.chat.value:
self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" self.prompt_input = (
self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
self.agent_label = "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:" )
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.agent_label = (
"\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
)
self.response_split = "ASSISTANT:" self.response_split = "ASSISTANT:"
def build_prompt( def build_prompt(
@@ -167,7 +198,7 @@ class Conversation:
yield (role + ":", " " + message) yield (role + ":", " " + message)
else: else:
logging.warning("role with empty message: " + role) logging.warning("role with empty message: " + role)
yield (role + ":", ) yield (role + ":",)
def copy(self): def copy(self):
return Conversation( return Conversation(
@@ -199,7 +230,9 @@ conv_vicuna_v1_1 = Conversation(
class ShareGPTPrompter: class ShareGPTPrompter:
def __init__(self, prompt_style=None): def __init__(self, prompt_style=None):
if prompt_style != PromptStyle.chat.value: if prompt_style != PromptStyle.chat.value:
raise Exception(f"unsupported prompt_style for ShareGPTPrompter({prompt_style})") raise Exception(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
# def match_prompt_style(self): # def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value: # if self.prompt_style == PromptStyle.chat.value:

View File

@@ -7,7 +7,8 @@ from datasets import (
load_dataset, load_dataset,
IterableDataset, IterableDataset,
Dataset, Dataset,
concatenate_datasets, DatasetDict, concatenate_datasets,
DatasetDict,
) )
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@@ -33,11 +34,14 @@ from axolotl.prompters import (
JeopardyPrompter, JeopardyPrompter,
CompletionPrompter, CompletionPrompter,
MultipleChoiceExplainPrompter, MultipleChoiceExplainPrompter,
SummarizeTLDRPrompter, MultipleChoiceConcisePrompter, SummarizeTLDRPrompter,
MultipleChoiceConcisePrompter,
) )
def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict: def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
md5( md5(
@@ -45,7 +49,8 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+ "|" + tokenizer_name + "|"
+ tokenizer_name
).encode("utf-8") ).encode("utf-8")
).hexdigest() ).hexdigest()
) )
@@ -57,7 +62,9 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
dataset = None dataset = None
try: try:
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True) dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
)
dataset = dataset["train"] dataset = dataset["train"]
except: except:
pass pass
@@ -88,7 +95,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
) )
elif ds_from_hub: elif ds_from_hub:
if d.data_files: if d.data_files:
ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True) ds = load_dataset(
d.path,
streaming=False,
data_files=d.data_files,
use_auth_token=True,
)
else: else:
ds = load_dataset(d.path, streaming=False, use_auth_token=True) ds = load_dataset(d.path, streaming=False, use_auth_token=True)
else: else:
@@ -100,49 +112,65 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
raise Exception("unhandled dataset load") raise Exception("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if d.shards:
ds = ds.shuffle(seed=42)["train"].shard( ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
num_shards=cfg.shards, index=0
)
d_type = d.type d_type = d.type
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if (ds_strategy := load(d.type, tokenizer, cfg)): if ds_strategy := load(d.type, tokenizer, cfg):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "alpaca": elif d_base_type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy( ds_strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "explainchoice": elif d_base_type == "explainchoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceExplainPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len MultipleChoiceExplainPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "concisechoice": elif d_base_type == "concisechoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceConcisePrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len MultipleChoiceConcisePrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "summarizetldr": elif d_base_type == "summarizetldr":
ds_strategy = SummarizeTLDRPromptTokenizingStrategy( ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
SummarizeTLDRPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len SummarizeTLDRPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "jeopardy": elif d_base_type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy( ds_strategy = JeopardyPromptTokenizingStrategy(
JeopardyPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len JeopardyPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "oasst": elif d_base_type == "oasst":
ds_strategy = OpenAssistantPromptTokenizingStrategy( ds_strategy = OpenAssistantPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
@@ -166,7 +194,10 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "sharegpt": elif d_base_type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy( ds_strategy = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ShareGPTPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
@@ -196,12 +227,16 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
logging.info( logging.info(
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True) dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
return dataset return dataset
def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset): def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
) -> (Dataset, Dataset):
max_packed_sequence_len = ( max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
) )
@@ -221,7 +256,8 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
+ str(max_packed_sequence_len) + str(max_packed_sequence_len)
+ seed + seed
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+ "|" + tokenizer_name + "|"
+ tokenizer_name
).encode("utf-8") ).encode("utf-8")
).hexdigest() ).hexdigest()
) )
@@ -237,7 +273,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
logging.info( logging.info(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True) dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
)
dataset = dataset["train"] dataset = dataset["train"]
except: except:
pass pass
@@ -254,7 +292,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
logging.info( logging.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True) dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
else: else:
dataset = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
@@ -279,9 +319,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
d d
for d in dataset for d in dataset
if len(d["input_ids"]) < cfg.sequence_len if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0 and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"]) and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"]) and len(d["input_ids"]) == len(d["labels"])
] ]
) )
@@ -294,7 +334,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
logging.info( logging.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True) dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
else: else:
dataset = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path

View File

@@ -11,7 +11,8 @@ from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
PreTrainedModel, PreTrainedModel,
AutoConfig, BitsAndBytesConfig, AutoConfig,
BitsAndBytesConfig,
) )
try: try:
@@ -244,7 +245,9 @@ def load_model(
embeddings_len = math.ceil(len(tokenizer) / 32) * 32 embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len) model.resize_token_embeddings(embeddings_len)
if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit: if (
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
) and not cfg.load_4bit:
logging.info("converting PEFT model w/ prepare_model_for_int8_training") logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model) model = prepare_model_for_int8_training(model)
@@ -265,7 +268,11 @@ def load_model(
m.scales = m.scales.half() m.scales = m.scales.half()
m.bias = m.bias.half() m.bias = m.bias.half()
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1 and cfg.load_4bit: if (
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
and cfg.load_4bit
):
# llama is PROBABLY model parallelizable, but the default isn't that it is # llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see # so let's only set it for the 4bit, see
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133 # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133

View File

@@ -17,10 +17,12 @@ from axolotl.utils.callbacks import SavePeftModelCallback
class OneCycleLRSchedulerTrainer(Trainer): class OneCycleLRSchedulerTrainer(Trainer):
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): def create_scheduler(
optimizer=self.optimizer if optimizer is None else optimizer self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
num_warmup_steps=self.args.get_warmup_steps(num_training_steps) ):
num_training_steps=num_training_steps optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
num_training_steps = num_training_steps
pct_start = num_warmup_steps / num_training_steps pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR( self.lr_scheduler = OneCycleLR(
@@ -203,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
) )
callbacks.append(early_stop_cb) callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter == 'lora': # only save in rank 0 if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
@@ -214,7 +216,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else: else:
data_collator_kwargs["pad_to_multiple_of"] = 8 data_collator_kwargs["pad_to_multiple_of"] = 8
trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and cfg.fsdp
else transformers.Trainer
)
trainer = trainer_cls( trainer = trainer_cls(
model=model, model=model,
train_dataset=train_dataset, train_dataset=train_dataset,