Fix llama3 chat_template (extra <|eot_id|> on last turn) (#1635)

* Fix llama3 chat_template (the {{eos_token}} leads to an extra <|eot_id|> being added in the last turn). Output now matches official Llama 3 Instruct model

* add tests

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Leonard
2024-05-21 22:08:53 +09:00
committed by GitHub
parent 22ae21a6c2
commit 7c2bf3091f
2 changed files with 86 additions and 1 deletions

View File

@@ -0,0 +1,85 @@
"""
tests for chat_template prompt strategy
"""
import unittest
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates
@pytest.fixture(name="sharegpt_dataset")
def fixture_sharegpt_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"conversations": [
{
"from": "human",
"value": "hello",
},
{
"from": "gpt",
"value": "hello",
},
{
"from": "human",
"value": "goodbye",
},
{
"from": "gpt",
"value": "goodbye",
},
]
}
]
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer
class TestSharegptChatTemplateLlama3:
"""
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
"""
def test_llama3(self, llama3_tokenizer, sharegpt_dataset):
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
llama3_tokenizer,
False,
512,
)
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
# fmt: off
assert input_ids == [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
if __name__ == "__main__":
unittest.main()