experimental llama 2 chat support (#296)
* experimental llama 2 chat support * few small fixes * llama2_chat * small fix to follow original implementation * small fixes and added fixtures/tests * fix -mixed up inference and finetuning conversations * args - small fix * small fix * small adjustment and warning * fix with pre-commit --------- Co-authored-by: Jan Philipp Harries <jpdus@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bb53a165f5
commit
3392270544
1
tests/fixtures/conversation.tokenized_llama2chat.json
vendored
Normal file
1
tests/fixtures/conversation.tokenized_llama2chat.json
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -4,13 +4,17 @@ import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
SystemDataPrompter,
|
||||
)
|
||||
from axolotl.prompt_strategies.llama2_chat import (
|
||||
Llama2ChatPrompter,
|
||||
LLama2ChatTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
@@ -135,5 +139,85 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
assert example["input_ids"][9] == 11889 # USER
|
||||
|
||||
|
||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
# woraround because official Meta repos are not open
|
||||
|
||||
def test_llama2_chat_integration(self):
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
conversation = json.loads(data)
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
tokenized_conversation = json.loads(data)
|
||||
prompter = Llama2ChatPrompter()
|
||||
strat = LLama2ChatTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
4096,
|
||||
)
|
||||
example = strat.tokenize_prompt(conversation)
|
||||
for fields in ["input_ids", "attention_mask", "labels"]:
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
|
||||
def compare_with_transformers_integration(self):
|
||||
# this needs transformers >= v4.31.0
|
||||
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
# from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
|
||||
# broken as of 23/7/20
|
||||
# see https://github.com/huggingface/transformers/pull/24935
|
||||
# pylint: disable=C0103
|
||||
DEFAULT_SYSTEM_PROMPT = """\
|
||||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
conversation = json.loads(data)
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
tokenized_conversation = json.loads(data)
|
||||
|
||||
user_input = []
|
||||
answers = []
|
||||
for msg in conversation["conversations"]:
|
||||
if msg["from"] == "human":
|
||||
user_input.append(msg["value"])
|
||||
else:
|
||||
answers.append(msg["value"])
|
||||
hf_conf = Conversation(
|
||||
text=user_input[-1],
|
||||
past_user_inputs=[B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + user_input[0]]
|
||||
+ user_input[1:-1],
|
||||
generated_responses=answers,
|
||||
)
|
||||
# pylint: disable=W0212
|
||||
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
|
||||
|
||||
self.assertEqual(
|
||||
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user