Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels

Fix tokenizing labels
This commit is contained in:
Wing Lian
2023-06-15 08:13:43 -04:00
committed by GitHub
3 changed files with 92 additions and 23 deletions

View File

@@ -20,11 +20,36 @@ def load(tokenizer, cfg):
class AlpacaConcisePrompter(AlpacaPrompter): class AlpacaConcisePrompter(AlpacaPrompter):
""" """
Alpaca Prompter extending the system prompt to ask for concise answers Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
""" """
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n" system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n" system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
class AlpacaChatPrompter(AlpacaPrompter):
    """
    Alpaca Chat Prompter extending the system prompt for chat-instruct answers
    """

    # Preamble whose wording refers to an instruction paired with an input.
    system_prompt = (
        "Below is an instruction from a USER that describes a task, "
        "paired with an input that provides further context. "
        "The ASSISTANT writes a response that concisely and appropriately "
        "completes the request.\n\n"
    )
    # Preamble whose wording refers to an instruction only (no input).
    system_no_input_prompt = (
        "Below is an instruction from a USER that describes a task. "
        "The ASSISTANT writes a response that appropriately and concisely "
        "completes the request.\n\n"
    )

    def __init__(self):  # pylint: disable=super-init-not-called
        # The style is pinned to CHAT instead of being accepted from the caller,
        # which is why the parent initializer is deliberately not invoked.
        self.prompt_style = PromptStyle.CHAT.value
        # NOTE(review): match_prompt_style is not defined in this class;
        # presumably inherited from AlpacaPrompter - confirm.
        self.match_prompt_style()
class NoSystemPrompter(AlpacaPrompter):
    """
    Null Prompter with no system prompts
    """

    # Templates contain only the raw fields followed by a single space;
    # no system text or section headers are added.
    prompt_no_input = "{instruction} "
    prompt_input = "{instruction} {input} "

    def __init__(self):  # pylint: disable=super-init-not-called
        """Intentionally do nothing; the parent initializer is not invoked."""
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +89,7 @@ def load_concise(tokenizer, cfg):
def load_qa(tokenizer, cfg): def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy( return AlpacaQAPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value), AlpacaChatPrompter(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
@@ -73,7 +98,7 @@ def load_qa(tokenizer, cfg):
def load_camel_ai(tokenizer, cfg): def load_camel_ai(tokenizer, cfg):
return CamelAIPromptTokenizingStrategy( return CamelAIPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value), AlpacaChatPrompter(),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
response, response,
) = self.parse_instruction_fields(prompt) ) = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, response) user_prompt = next(
tokenized_full_prompt = self._tokenize(full_prompt) iter(
if not self.train_on_inputs: self.prompter.build_prompt(
user_prompt = next( instruction,
iter( input,
self.prompter.build_prompt(
instruction,
input,
)
) )
) )
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) )
user_prompt_len = len(tokenized_user_prompt["input_ids"]) tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [ tokenized_prompt["labels"] = [-100] * user_prompt_len
-100 tokenized_res_prompt = self._tokenize(
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] response, strip_bos_token=True, add_eos_token=True
)
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
return tokenized_full_prompt return tokenized_prompt
def _build_full_prompt( def _build_full_prompt(
self, instruction, input, response # pylint: disable=redefined-builtin self, instruction, input, response # pylint: disable=redefined-builtin

View File

@@ -6,8 +6,12 @@ from pathlib import Path
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompters import ShareGPTPrompter from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
logging.basicConfig(level="INFO") logging.basicConfig(level="INFO")
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) )
def test_sharegpt_integration(self): def test_sharegpt_integration(self):
print(Path(__file__).parent)
with open( with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin: ) as fin:
@@ -53,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields]) self.assertEqual(example[fields], tokenized_conversation[fields])
def test_no_sys_prompt(self):
    """
    tests the interface between the user and assistant parts

    With NoSystemPrompter and the third strategy argument set to False, the
    token just before the response must be masked with -100 while the
    response token keeps its own id as the label.
    """
    prompter = NoSystemPrompter()
    # pylint: disable=duplicate-code
    strat = AlpacaPromptTokenizingStrategy(
        prompter,
        self.tokenizer,
        False,  # presumably train_on_inputs - confirm against the strategy signature
        2048,  # presumably sequence_len - confirm against the strategy signature
    )
    sample = {
        "instruction": "hello cruel. lorem ipsum dolor sit amet.",
        "output": "world!",
    }
    example = strat.tokenize_prompt(sample)
    # NOTE(review): 3186 is presumably the tokenizer's id for "world" - confirm.
    world_idx = example["input_ids"].index(3186)
    # Guard against index 0, where `world_idx - 1` would wrap to the last label.
    self.assertGreater(world_idx, 0)
    self.assertEqual(example["labels"][world_idx], 3186)
    self.assertEqual(example["labels"][world_idx - 1], -100)
def test_alpaca(self):
    """
    tests the interface between the user and assistant parts
    when using the default AlpacaPrompter

    The token just before the response must be masked with -100 while the
    response token keeps its own id as the label.
    """
    # pylint: disable=duplicate-code
    prompter = AlpacaPrompter()
    strat = AlpacaPromptTokenizingStrategy(
        prompter,
        self.tokenizer,
        False,  # presumably train_on_inputs - confirm against the strategy signature
        2048,  # presumably sequence_len - confirm against the strategy signature
    )
    sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
    example = strat.tokenize_prompt(sample)
    # NOTE(review): 6324 is presumably the tokenizer's id for the leading
    # "Hi" of the response - confirm.
    first_response_idx = example["input_ids"].index(6324)
    # Guard against index 0, where `first_response_idx - 1` would wrap to the
    # last label.
    self.assertGreater(first_response_idx, 0)
    self.assertEqual(example["labels"][first_response_idx], 6324)
    self.assertEqual(example["labels"][first_response_idx - 1], -100)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()