Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Fix tokenizing labels
This commit is contained in:
@@ -20,11 +20,36 @@ def load(tokenizer, cfg):
|
|||||||
|
|
||||||
class AlpacaConcisePrompter(AlpacaPrompter):
|
class AlpacaConcisePrompter(AlpacaPrompter):
|
||||||
"""
|
"""
|
||||||
Alpaca Prompter extending the system prompt to ask for concise answers
|
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
|
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
|
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
class AlpacaChatPrompter(AlpacaPrompter):
|
||||||
|
"""
|
||||||
|
Alpaca Chat Prompter extending the system prompt for chat-instruct answers
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||||
|
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||||
|
|
||||||
|
def __init__(self): # pylint: disable=super-init-not-called
|
||||||
|
self.prompt_style = PromptStyle.CHAT.value
|
||||||
|
self.match_prompt_style()
|
||||||
|
|
||||||
|
|
||||||
|
class NoSystemPrompter(AlpacaPrompter):
|
||||||
|
"""
|
||||||
|
Null Prompter with no system prompts
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt_input = "{instruction} {input} "
|
||||||
|
prompt_no_input = "{instruction} "
|
||||||
|
|
||||||
|
def __init__(self): # pylint: disable=super-init-not-called
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
@@ -64,7 +89,7 @@ def load_concise(tokenizer, cfg):
|
|||||||
|
|
||||||
def load_qa(tokenizer, cfg):
|
def load_qa(tokenizer, cfg):
|
||||||
return AlpacaQAPromptTokenizingStrategy(
|
return AlpacaQAPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
AlpacaChatPrompter(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
@@ -73,7 +98,7 @@ def load_qa(tokenizer, cfg):
|
|||||||
|
|
||||||
def load_camel_ai(tokenizer, cfg):
|
def load_camel_ai(tokenizer, cfg):
|
||||||
return CamelAIPromptTokenizingStrategy(
|
return CamelAIPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
AlpacaChatPrompter(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
response,
|
response,
|
||||||
) = self.parse_instruction_fields(prompt)
|
) = self.parse_instruction_fields(prompt)
|
||||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
user_prompt = next(
|
||||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
iter(
|
||||||
if not self.train_on_inputs:
|
self.prompter.build_prompt(
|
||||||
user_prompt = next(
|
instruction,
|
||||||
iter(
|
input,
|
||||||
self.prompter.build_prompt(
|
|
||||||
instruction,
|
|
||||||
input,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
)
|
||||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||||
|
if not self.train_on_inputs:
|
||||||
|
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
tokenized_full_prompt["labels"] = [
|
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||||
-100
|
tokenized_res_prompt = self._tokenize(
|
||||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
response, strip_bos_token=True, add_eos_token=True
|
||||||
|
)
|
||||||
|
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||||
|
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||||
|
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||||
|
|
||||||
return tokenized_full_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
def _build_full_prompt(
|
def _build_full_prompt(
|
||||||
self, instruction, input, response # pylint: disable=redefined-builtin
|
self, instruction, input, response # pylint: disable=redefined-builtin
|
||||||
|
|||||||
@@ -6,8 +6,12 @@ from pathlib import Path
|
|||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||||
from axolotl.prompters import ShareGPTPrompter
|
from axolotl.prompt_tokenizers import (
|
||||||
|
AlpacaPromptTokenizingStrategy,
|
||||||
|
ShareGPTPromptTokenizingStrategy,
|
||||||
|
)
|
||||||
|
from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
|
||||||
|
|
||||||
logging.basicConfig(level="INFO")
|
logging.basicConfig(level="INFO")
|
||||||
|
|
||||||
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_sharegpt_integration(self):
|
def test_sharegpt_integration(self):
|
||||||
print(Path(__file__).parent)
|
|
||||||
with open(
|
with open(
|
||||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||||
) as fin:
|
) as fin:
|
||||||
@@ -53,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||||
|
|
||||||
|
def test_no_sys_prompt(self):
|
||||||
|
"""
|
||||||
|
tests the interface between the user and assistant parts
|
||||||
|
"""
|
||||||
|
prompter = NoSystemPrompter()
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
strat = AlpacaPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
sample = {
|
||||||
|
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
|
||||||
|
"output": "world!",
|
||||||
|
}
|
||||||
|
example = strat.tokenize_prompt(sample)
|
||||||
|
world_idx = example["input_ids"].index(3186)
|
||||||
|
assert example["labels"][world_idx] == 3186
|
||||||
|
assert example["labels"][world_idx - 1] == -100
|
||||||
|
|
||||||
|
def test_alpaca(self):
|
||||||
|
"""
|
||||||
|
tests the interface between the user and assistant parts
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
prompter = AlpacaPrompter()
|
||||||
|
strat = AlpacaPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
|
||||||
|
example = strat.tokenize_prompt(sample)
|
||||||
|
world_idx = example["input_ids"].index(6324)
|
||||||
|
assert example["labels"][world_idx] == 6324
|
||||||
|
assert example["labels"][world_idx - 1] == -100
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user