apply black formatting
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, tokenizer, cfg):
|
def load(strategy, tokenizer, cfg):
|
||||||
try:
|
try:
|
||||||
load_fn = "load"
|
load_fn = "load"
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import (
|
||||||
|
AlpacaPromptTokenizingStrategy,
|
||||||
|
InstructionPromptTokenizingStrategy,
|
||||||
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg):
|
||||||
return AlpacaPromptTokenizingStrategy(
|
return AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
AlpacaPrompter(PromptStyle.chat.value),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -19,5 +25,8 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|||||||
|
|
||||||
def load_qa(tokenizer, cfg):
|
def load_qa(tokenizer, cfg):
|
||||||
return AlpacaQAPromptTokenizingStrategy(
|
return AlpacaQAPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
AlpacaPrompter(PromptStyle.chat.value),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,5 +4,8 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle
|
|||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg):
|
||||||
return AlpacaPromptTokenizingStrategy(
|
return AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.instruct), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
AlpacaPrompter(PromptStyle.instruct),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
|||||||
class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
question = prompt["instruction"]
|
question = prompt["instruction"]
|
||||||
answer = prompt["revision"] # don't use prompt[answer], that's data we don't want in the dataset
|
answer = prompt[
|
||||||
|
"revision"
|
||||||
|
] # don't use prompt[answer], that's data we don't want in the dataset
|
||||||
return (
|
return (
|
||||||
question,
|
question,
|
||||||
"",
|
"",
|
||||||
@@ -48,8 +50,12 @@ Answer: {answer}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
|
scores = yaml.dump(
|
||||||
critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
|
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
||||||
|
)
|
||||||
|
critiques = yaml.dump(
|
||||||
|
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
||||||
|
)
|
||||||
evaluation = scores + critiques
|
evaluation = scores + critiques
|
||||||
question = prompt["instruction"]
|
question = prompt["instruction"]
|
||||||
answer = prompt["answer"]
|
answer = prompt["answer"]
|
||||||
@@ -76,13 +82,19 @@ Evaluation:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
|
scores = yaml.dump(
|
||||||
critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
|
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
||||||
|
)
|
||||||
|
critiques = yaml.dump(
|
||||||
|
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
||||||
|
)
|
||||||
evaluation = scores + critiques
|
evaluation = scores + critiques
|
||||||
question = prompt["instruction"]
|
question = prompt["instruction"]
|
||||||
answer = prompt["answer"]
|
answer = prompt["answer"]
|
||||||
return (
|
return (
|
||||||
self.user_prompt.format(question=question, answer=answer, evaluation=evaluation),
|
self.user_prompt.format(
|
||||||
|
question=question, answer=answer, evaluation=evaluation
|
||||||
|
),
|
||||||
"",
|
"",
|
||||||
prompt["revision"],
|
prompt["revision"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,20 +30,34 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
# this should include a bos token, no eos token, strip trailing "\n<START>"
|
# this should include a bos token, no eos token, strip trailing "\n<START>"
|
||||||
if message.endswith("\n<START>"):
|
if message.endswith("\n<START>"):
|
||||||
message = message[:-8]
|
message = message[:-8]
|
||||||
res = self._tokenize(prefix + "Persona: " + message.strip(), add_eos_token=False, strip_bos_token=False)
|
res = self._tokenize(
|
||||||
|
prefix + "Persona: " + message.strip(),
|
||||||
|
add_eos_token=False,
|
||||||
|
strip_bos_token=False,
|
||||||
|
)
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif role == "human":
|
elif role == "human":
|
||||||
prefix = "<|user|>"
|
prefix = "<|user|>"
|
||||||
res = self._tokenize(prefix + " " + message.strip(), add_eos_token=False, strip_bos_token=True)
|
res = self._tokenize(
|
||||||
|
prefix + " " + message.strip(),
|
||||||
|
add_eos_token=False,
|
||||||
|
strip_bos_token=True,
|
||||||
|
)
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif role == "bot":
|
elif role == "bot":
|
||||||
prefix = "<|model|>"
|
prefix = "<|model|>"
|
||||||
res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
|
res = self._tokenize(
|
||||||
|
prefix + " " + message.strip(),
|
||||||
|
add_eos_token=True,
|
||||||
|
strip_bos_token=True,
|
||||||
|
)
|
||||||
# mask out the prefix token, rest is not masked out from labels
|
# mask out the prefix token, rest is not masked out from labels
|
||||||
# make sure we create the labels first, otherwise we get incorrect lengths
|
# make sure we create the labels first, otherwise we get incorrect lengths
|
||||||
labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):]
|
labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [
|
||||||
|
*copy.deepcopy(res["input_ids"])
|
||||||
|
][len(self.bot_prefix_token_ids) :]
|
||||||
else:
|
else:
|
||||||
logging.warning(f"unknown role in conversation: {role}")
|
logging.warning(f"unknown role in conversation: {role}")
|
||||||
res = defaultdict(lambda: [])
|
res = defaultdict(lambda: [])
|
||||||
@@ -51,8 +65,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
input_len = len(input_ids)
|
input_len = len(input_ids)
|
||||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||||
result["attention_mask"][current_len : current_len + input_len] = [
|
result["attention_mask"][current_len : current_len + input_len] = [
|
||||||
1 if x != self.tokenizer.pad_token_id else 0
|
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
||||||
for x in input_ids
|
|
||||||
]
|
]
|
||||||
result["labels"][current_len : current_len + input_len] = labels
|
result["labels"][current_len : current_len + input_len] = labels
|
||||||
current_len += input_len
|
current_len += input_len
|
||||||
@@ -74,10 +87,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
result["attention_mask"].append(1)
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
if (
|
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||||
result["input_ids"][0] == self.tokenizer.bos_token_id
|
|
||||||
and strip_bos_token
|
|
||||||
):
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
|
|||||||
@@ -59,10 +59,14 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
user_prompt = next(iter(self.prompter.build_prompt(
|
user_prompt = next(
|
||||||
instruction,
|
iter(
|
||||||
input,
|
self.prompter.build_prompt(
|
||||||
)))
|
instruction,
|
||||||
|
input,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
@@ -73,11 +77,15 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
|
|
||||||
def _build_full_prompt(self, instruction, input, response):
|
def _build_full_prompt(self, instruction, input, response):
|
||||||
return next(iter(self.prompter.build_prompt(
|
return next(
|
||||||
instruction,
|
iter(
|
||||||
input,
|
self.prompter.build_prompt(
|
||||||
response,
|
instruction,
|
||||||
)))
|
input,
|
||||||
|
response,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||||
result = self.tokenizer(
|
result = self.tokenizer(
|
||||||
@@ -95,10 +103,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
result["attention_mask"].append(1)
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
if (
|
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||||
result["input_ids"][0] == self.tokenizer.bos_token_id
|
|
||||||
and strip_bos_token
|
|
||||||
):
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
@@ -201,10 +206,14 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
)
|
)
|
||||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
user_prompt = next(iter(self.prompter.build_prompt(
|
user_prompt = next(
|
||||||
instruction,
|
iter(
|
||||||
input,
|
self.prompter.build_prompt(
|
||||||
)))
|
instruction,
|
||||||
|
input,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
@@ -215,13 +224,17 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
|
|
||||||
def _build_full_prompt(self, instruction, input, output, reflection, corrected):
|
def _build_full_prompt(self, instruction, input, output, reflection, corrected):
|
||||||
return next(iter(self.prompter.build_prompt(
|
return next(
|
||||||
instruction,
|
iter(
|
||||||
input,
|
self.prompter.build_prompt(
|
||||||
output,
|
instruction,
|
||||||
reflection,
|
input,
|
||||||
corrected,
|
output,
|
||||||
)))
|
reflection,
|
||||||
|
corrected,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True):
|
def _tokenize(self, prompt, add_eos_token=True):
|
||||||
result = self.tokenizer(
|
result = self.tokenizer(
|
||||||
@@ -265,21 +278,27 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
user_token = self._get_user_token()
|
user_token = self._get_user_token()
|
||||||
assistant_token = self._get_assistant_token()
|
assistant_token = self._get_assistant_token()
|
||||||
try:
|
try:
|
||||||
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
|
for i, part in enumerate(
|
||||||
|
self.prompter.build_prompt(prompt["conversations"])
|
||||||
|
):
|
||||||
if isinstance(part, tuple):
|
if isinstance(part, tuple):
|
||||||
if part[0] == "USER:":
|
if part[0] == "USER:":
|
||||||
part = part[0] + part[1] if not user_token else part[1]
|
part = part[0] + part[1] if not user_token else part[1]
|
||||||
# this is still the user query, we should
|
# this is still the user query, we should
|
||||||
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
|
res = self._tokenize(
|
||||||
|
part.strip(), add_eos_token=False, strip_bos_token=True
|
||||||
|
)
|
||||||
if user_token:
|
if user_token:
|
||||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif part[0] == "ASSISTANT:":
|
elif part[0] == "ASSISTANT:":
|
||||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||||
part = part[0] + part[1] if not assistant_token else part[1]
|
part = part[0] + part[1] if not assistant_token else part[1]
|
||||||
# this should be the assistent response, should end with an eos token
|
# this should be the assistent response, should end with an eos token
|
||||||
res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
|
res = self._tokenize(
|
||||||
|
part.strip(), add_eos_token=True, strip_bos_token=True
|
||||||
|
)
|
||||||
if assistant_token:
|
if assistant_token:
|
||||||
res["input_ids"] = [assistant_token, *res["input_ids"]]
|
res["input_ids"] = [assistant_token, *res["input_ids"]]
|
||||||
# not masked out from labels
|
# not masked out from labels
|
||||||
@@ -288,15 +307,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
logging.warning("unhandled role: " + part[0])
|
logging.warning("unhandled role: " + part[0])
|
||||||
else:
|
else:
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False)
|
res = self._tokenize(
|
||||||
|
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||||
|
)
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
input_len = len(input_ids)
|
input_len = len(input_ids)
|
||||||
result["input_ids"][current_len : current_len + input_len] = input_ids
|
result["input_ids"][current_len : current_len + input_len] = input_ids
|
||||||
result["attention_mask"][current_len : current_len + input_len] = [
|
result["attention_mask"][current_len : current_len + input_len] = [
|
||||||
1 if x != self.tokenizer.pad_token_id else 0
|
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
|
||||||
for x in input_ids
|
|
||||||
]
|
]
|
||||||
result["labels"][current_len : current_len + input_len] = labels
|
result["labels"][current_len : current_len + input_len] = labels
|
||||||
current_len += input_len
|
current_len += input_len
|
||||||
@@ -320,10 +340,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
result["attention_mask"].append(1)
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
if (
|
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||||
result["input_ids"][0] == self.tokenizer.bos_token_id
|
|
||||||
and strip_bos_token
|
|
||||||
):
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
|
|||||||
@@ -23,12 +23,22 @@ class AlpacaPrompter:
|
|||||||
|
|
||||||
def match_prompt_style(self):
|
def match_prompt_style(self):
|
||||||
if self.prompt_style == PromptStyle.instruct.value:
|
if self.prompt_style == PromptStyle.instruct.value:
|
||||||
self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
self.prompt_input = (
|
||||||
self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
|
self.system_prompt
|
||||||
|
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
)
|
||||||
|
self.prompt_no_input = (
|
||||||
|
self.system_no_input_prompt
|
||||||
|
+ "### Instruction:\n{instruction}\n\n### Response:\n"
|
||||||
|
)
|
||||||
self.response_split = "### Response:"
|
self.response_split = "### Response:"
|
||||||
if self.prompt_style == PromptStyle.chat.value:
|
if self.prompt_style == PromptStyle.chat.value:
|
||||||
self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
self.prompt_input = (
|
||||||
self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||||
|
)
|
||||||
|
self.prompt_no_input = (
|
||||||
|
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||||
|
)
|
||||||
self.response_split = "ASSISTANT:"
|
self.response_split = "ASSISTANT:"
|
||||||
|
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
@@ -55,12 +65,15 @@ class UnpromptedPrompter(AlpacaPrompter):
|
|||||||
system_prompt = ""
|
system_prompt = ""
|
||||||
system_no_input_prompt = ""
|
system_no_input_prompt = ""
|
||||||
|
|
||||||
|
|
||||||
class JeopardyPrompter(AlpacaPrompter):
|
class JeopardyPrompter(AlpacaPrompter):
|
||||||
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
|
||||||
|
|
||||||
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
||||||
system_prompt = "Choose the answer that best answers the question. Explain your reasoning."
|
system_prompt = (
|
||||||
|
"Choose the answer that best answers the question. Explain your reasoning."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
||||||
@@ -68,11 +81,15 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
|||||||
|
|
||||||
|
|
||||||
class SummarizeTLDRPrompter(AlpacaPrompter):
|
class SummarizeTLDRPrompter(AlpacaPrompter):
|
||||||
prompt_no_input = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
prompt_no_input = (
|
||||||
|
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompletionPrompter(AlpacaPrompter):
|
class CompletionPrompter(AlpacaPrompter):
|
||||||
def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
|
def build_prompt(
|
||||||
|
self, instruction: str, input=None, output=None
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
yield instruction
|
yield instruction
|
||||||
|
|
||||||
def get_response(self, output: str) -> str:
|
def get_response(self, output: str) -> str:
|
||||||
@@ -91,7 +108,9 @@ class ReflectAlpacaPrompter:
|
|||||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
||||||
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
|
||||||
|
|
||||||
prompt_input = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
prompt_input = (
|
||||||
|
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
)
|
||||||
prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
|
prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
|
||||||
agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
||||||
response_split = "### Response:"
|
response_split = "### Response:"
|
||||||
@@ -102,14 +121,26 @@ class ReflectAlpacaPrompter:
|
|||||||
|
|
||||||
def match_prompt_style(self):
|
def match_prompt_style(self):
|
||||||
if self.prompt_style == PromptStyle.instruct.value:
|
if self.prompt_style == PromptStyle.instruct.value:
|
||||||
self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
self.prompt_input = (
|
||||||
self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
|
self.system_prompt
|
||||||
|
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
)
|
||||||
|
self.prompt_no_input = (
|
||||||
|
self.system_no_input_prompt
|
||||||
|
+ "### Instruction:\n{instruction}\n\n### Response:\n"
|
||||||
|
)
|
||||||
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
|
||||||
self.response_split = "### Final Response:"
|
self.response_split = "### Final Response:"
|
||||||
if self.prompt_style == PromptStyle.chat.value:
|
if self.prompt_style == PromptStyle.chat.value:
|
||||||
self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
self.prompt_input = (
|
||||||
self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||||
self.agent_label = "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
|
)
|
||||||
|
self.prompt_no_input = (
|
||||||
|
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||||
|
)
|
||||||
|
self.agent_label = (
|
||||||
|
"\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
|
||||||
|
)
|
||||||
self.response_split = "ASSISTANT:"
|
self.response_split = "ASSISTANT:"
|
||||||
|
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
@@ -167,7 +198,7 @@ class Conversation:
|
|||||||
yield (role + ":", " " + message)
|
yield (role + ":", " " + message)
|
||||||
else:
|
else:
|
||||||
logging.warning("role with empty message: " + role)
|
logging.warning("role with empty message: " + role)
|
||||||
yield (role + ":", )
|
yield (role + ":",)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return Conversation(
|
return Conversation(
|
||||||
@@ -199,7 +230,9 @@ conv_vicuna_v1_1 = Conversation(
|
|||||||
class ShareGPTPrompter:
|
class ShareGPTPrompter:
|
||||||
def __init__(self, prompt_style=None):
|
def __init__(self, prompt_style=None):
|
||||||
if prompt_style != PromptStyle.chat.value:
|
if prompt_style != PromptStyle.chat.value:
|
||||||
raise Exception(f"unsupported prompt_style for ShareGPTPrompter({prompt_style})")
|
raise Exception(
|
||||||
|
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
||||||
|
)
|
||||||
|
|
||||||
# def match_prompt_style(self):
|
# def match_prompt_style(self):
|
||||||
# if self.prompt_style == PromptStyle.chat.value:
|
# if self.prompt_style == PromptStyle.chat.value:
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ from datasets import (
|
|||||||
load_dataset,
|
load_dataset,
|
||||||
IterableDataset,
|
IterableDataset,
|
||||||
Dataset,
|
Dataset,
|
||||||
concatenate_datasets, DatasetDict,
|
concatenate_datasets,
|
||||||
|
DatasetDict,
|
||||||
)
|
)
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
@@ -33,11 +34,14 @@ from axolotl.prompters import (
|
|||||||
JeopardyPrompter,
|
JeopardyPrompter,
|
||||||
CompletionPrompter,
|
CompletionPrompter,
|
||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
SummarizeTLDRPrompter, MultipleChoiceConcisePrompter,
|
SummarizeTLDRPrompter,
|
||||||
|
MultipleChoiceConcisePrompter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
|
def load_tokenized_prepared_datasets(
|
||||||
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
|
) -> DatasetDict:
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5(
|
md5(
|
||||||
@@ -45,7 +49,8 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
||||||
+ "|" + tokenizer_name
|
+ "|"
|
||||||
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
)
|
)
|
||||||
@@ -57,7 +62,9 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|||||||
dataset = None
|
dataset = None
|
||||||
try:
|
try:
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
|
dataset = load_dataset(
|
||||||
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
|
||||||
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset["train"]
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@@ -88,7 +95,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|||||||
)
|
)
|
||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
if d.data_files:
|
if d.data_files:
|
||||||
ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True)
|
ds = load_dataset(
|
||||||
|
d.path,
|
||||||
|
streaming=False,
|
||||||
|
data_files=d.data_files,
|
||||||
|
use_auth_token=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ds = load_dataset(d.path, streaming=False, use_auth_token=True)
|
ds = load_dataset(d.path, streaming=False, use_auth_token=True)
|
||||||
else:
|
else:
|
||||||
@@ -100,49 +112,65 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|||||||
raise Exception("unhandled dataset load")
|
raise Exception("unhandled dataset load")
|
||||||
# support for using a subset of the data
|
# support for using a subset of the data
|
||||||
if d.shards:
|
if d.shards:
|
||||||
ds = ds.shuffle(seed=42)["train"].shard(
|
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
|
||||||
num_shards=cfg.shards, index=0
|
|
||||||
)
|
|
||||||
d_type = d.type
|
d_type = d.type
|
||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||||
if (ds_strategy := load(d.type, tokenizer, cfg)):
|
if ds_strategy := load(d.type, tokenizer, cfg):
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "alpaca":
|
elif d_base_type == "alpaca":
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
AlpacaPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "explainchoice":
|
elif d_base_type == "explainchoice":
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||||
MultipleChoiceExplainPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
MultipleChoiceExplainPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "concisechoice":
|
elif d_base_type == "concisechoice":
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||||
MultipleChoiceConcisePrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
MultipleChoiceConcisePrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "summarizetldr":
|
elif d_base_type == "summarizetldr":
|
||||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||||
SummarizeTLDRPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
SummarizeTLDRPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "jeopardy":
|
elif d_base_type == "jeopardy":
|
||||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||||
JeopardyPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
JeopardyPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "oasst":
|
elif d_base_type == "oasst":
|
||||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
AlpacaPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
@@ -166,7 +194,10 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "sharegpt":
|
elif d_base_type == "sharegpt":
|
||||||
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
ShareGPTPrompter(d_prompt_style),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
@@ -196,12 +227,16 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
|
dataset.push_to_hub(
|
||||||
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
|
)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
|
def load_prepare_datasets(
|
||||||
|
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
||||||
|
) -> (Dataset, Dataset):
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
)
|
)
|
||||||
@@ -221,7 +256,8 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
|
|||||||
+ str(max_packed_sequence_len)
|
+ str(max_packed_sequence_len)
|
||||||
+ seed
|
+ seed
|
||||||
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
||||||
+ "|" + tokenizer_name
|
+ "|"
|
||||||
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
)
|
)
|
||||||
@@ -237,7 +273,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
|
dataset = load_dataset(
|
||||||
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
|
||||||
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset["train"]
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@@ -254,7 +292,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
|
dataset.push_to_hub(
|
||||||
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_tokenized_prepared_datasets(
|
dataset = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
@@ -279,9 +319,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
|
|||||||
d
|
d
|
||||||
for d in dataset
|
for d in dataset
|
||||||
if len(d["input_ids"]) < cfg.sequence_len
|
if len(d["input_ids"]) < cfg.sequence_len
|
||||||
and len(d["input_ids"]) > 0
|
and len(d["input_ids"]) > 0
|
||||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||||
and len(d["input_ids"]) == len(d["labels"])
|
and len(d["input_ids"]) == len(d["labels"])
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -294,7 +334,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
|
dataset.push_to_hub(
|
||||||
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_tokenized_prepared_datasets(
|
dataset = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ from transformers import (
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
AutoConfig, BitsAndBytesConfig,
|
AutoConfig,
|
||||||
|
BitsAndBytesConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -244,7 +245,9 @@ def load_model(
|
|||||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||||
model.resize_token_embeddings(embeddings_len)
|
model.resize_token_embeddings(embeddings_len)
|
||||||
|
|
||||||
if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
|
if (
|
||||||
|
(cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
|
||||||
|
) and not cfg.load_4bit:
|
||||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
||||||
model = prepare_model_for_int8_training(model)
|
model = prepare_model_for_int8_training(model)
|
||||||
|
|
||||||
@@ -265,7 +268,11 @@ def load_model(
|
|||||||
m.scales = m.scales.half()
|
m.scales = m.scales.half()
|
||||||
m.bias = m.bias.half()
|
m.bias = m.bias.half()
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1 and cfg.load_4bit:
|
if (
|
||||||
|
torch.cuda.device_count() > 1
|
||||||
|
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||||
|
and cfg.load_4bit
|
||||||
|
):
|
||||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||||
# so let's only set it for the 4bit, see
|
# so let's only set it for the 4bit, see
|
||||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||||
|
|||||||
@@ -17,10 +17,12 @@ from axolotl.utils.callbacks import SavePeftModelCallback
|
|||||||
|
|
||||||
|
|
||||||
class OneCycleLRSchedulerTrainer(Trainer):
|
class OneCycleLRSchedulerTrainer(Trainer):
|
||||||
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
|
def create_scheduler(
|
||||||
optimizer=self.optimizer if optimizer is None else optimizer
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
|
):
|
||||||
num_training_steps=num_training_steps
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
num_training_steps = num_training_steps
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
|
||||||
self.lr_scheduler = OneCycleLR(
|
self.lr_scheduler = OneCycleLR(
|
||||||
@@ -203,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
)
|
)
|
||||||
callbacks.append(early_stop_cb)
|
callbacks.append(early_stop_cb)
|
||||||
|
|
||||||
if cfg.local_rank == 0 and cfg.adapter == 'lora': # only save in rank 0
|
if cfg.local_rank == 0 and cfg.adapter == "lora": # only save in rank 0
|
||||||
callbacks.append(SavePeftModelCallback)
|
callbacks.append(SavePeftModelCallback)
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
@@ -214,7 +216,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
else:
|
else:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||||
|
|
||||||
trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
|
trainer_cls = (
|
||||||
|
OneCycleLRSchedulerTrainer
|
||||||
|
if cfg.lr_scheduler == "one_cycle" and cfg.fsdp
|
||||||
|
else transformers.Trainer
|
||||||
|
)
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
model=model,
|
model=model,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
|
|||||||
Reference in New Issue
Block a user