wip add new proposed message structure (#1904)
* wip add new proposed message structure * tokenization * wip * wip transform builder * wip make the chat dataset loadable * wip chatml + llama 3 new chat objects * chore: lint * chore: lint * fix tokenization * remove dacite dependency since we're using pydantic now * fix handling when already correctly split in messages * make sure to remove chat features from tokenized ds * move chat to be a input transform for messages * make sure llama3 has the bos token * remove non-working special token code * fix messages strat loader
This commit is contained in:
0
tests/core/chat/__init__.py
Normal file
0
tests/core/chat/__init__.py
Normal file
0
tests/core/chat/format/__init__.py
Normal file
0
tests/core/chat/format/__init__.py
Normal file
197
tests/core/chat/test_messages.py
Normal file
197
tests/core/chat/test_messages.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Tests for the chat messages module
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers import AddedToken, AutoTokenizer
|
||||
|
||||
from axolotl.core.chat.format.chatml import format_message
|
||||
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="llama_tokenizer")
|
||||
def llama_tokenizer_fixture():
|
||||
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="chatml_tokenizer")
|
||||
def llama_tokenizer_w_chatml(llama_tokenizer):
|
||||
llama_tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
llama_tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
||||
]
|
||||
)
|
||||
|
||||
return llama_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", name="chat_msgs")
|
||||
def chat_msgs_fixture():
|
||||
return {
|
||||
"conversation": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "value": "You are a helpful assistant."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "value": "What is today's stock price of Apple?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_call",
|
||||
"value": {
|
||||
"name": "get_date",
|
||||
"arguments": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "tool_call",
|
||||
"value": {
|
||||
"name": "get_stock_price",
|
||||
"arguments": {"symbol": "AAPL"},
|
||||
},
|
||||
},
|
||||
],
|
||||
"weight": 1,
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_response",
|
||||
"value": {
|
||||
"name": "get_date",
|
||||
"content": {"date": "2024-09-09"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "tool_response",
|
||||
"value": {
|
||||
"name": "get_stock_price",
|
||||
"content": {"symbol": "AAPL", "price": 123.45},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"value": "The stock price of Apple is $123.45.\n",
|
||||
"weight": 0,
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"value": "<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>",
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"value": "The stock price of Apple on September 9, 2024 is $123.45.",
|
||||
},
|
||||
],
|
||||
"weight": 1,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class TestMessagesCase:
|
||||
"""
|
||||
Test cases for the chat messages module
|
||||
"""
|
||||
|
||||
def test_tool_call_stringify(self, chat_msgs):
|
||||
chat_msgs_as_obj = Chats(**chat_msgs)
|
||||
assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
|
||||
chat_msgs_as_obj.conversation[2].content[1].value
|
||||
)
|
||||
|
||||
def test_chatml_formatted_wrapper(self, chat_msgs):
|
||||
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
|
||||
target_chatml = """<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
What is today's stock price of Apple?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_date", "arguments": {}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
|
||||
</tool_call>
|
||||
<|im_end|>
|
||||
<|im_start|>tool
|
||||
<tool_response>
|
||||
{"name": "get_date", "content": {"date": "2024-09-09"}}
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
The stock price of Apple is $123.45.
|
||||
<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
|
||||
assert target_chatml == str(chat_msg_formatted)
|
||||
|
||||
def test_chatml_formatting_tool_call(self, chat_msgs):
|
||||
chat_msgs_as_obj = Chats(**chat_msgs)
|
||||
target_chatml_turn2 = """<|im_start|>assistant\n<tool_call>\n{"name": "get_date", "arguments": {}}\n</tool_call>\n<tool_call>\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n</tool_call>\n<|im_end|>\n"""
|
||||
assert target_chatml_turn2 == str(
|
||||
format_message(chat_msgs_as_obj.conversation[2])
|
||||
)
|
||||
|
||||
def test_train_labels(self, chatml_tokenizer, chat_msgs):
|
||||
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
|
||||
tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)
|
||||
# fmt: off
|
||||
target_labels = [
|
||||
-100, -100, -100, # role
|
||||
27, 14506, 13735, 397, 5018, 609, 794,
|
||||
330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
|
||||
14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
|
||||
330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
|
||||
794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
|
||||
128256, # <|im_end|>
|
||||
-100 # trailing newline
|
||||
]
|
||||
# fmt: on
|
||||
assert tokenized["labels"] == target_labels
|
||||
|
||||
def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
|
||||
# also test if indivudal contents are set not to train
|
||||
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
|
||||
tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)
|
||||
# fmt: off
|
||||
target_labels = [
|
||||
-100, -100, -100, # role
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response
|
||||
27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
|
||||
315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
|
||||
5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
|
||||
8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
|
||||
4513, 13, 1774, 13,
|
||||
128256, # <|im_end|>
|
||||
-100, # trailing newline
|
||||
]
|
||||
# fmt: on
|
||||
assert tokenized["labels"] == target_labels
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
0
tests/prompt_strategies/messages/__init__.py
Normal file
0
tests/prompt_strategies/messages/__init__.py
Normal file
62
tests/prompt_strategies/messages/test_chat.py
Normal file
62
tests/prompt_strategies/messages/test_chat.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
from axolotl.prompt_strategies.messages.chat import load
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class TestMessagesChatLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy.
|
||||
"""
|
||||
|
||||
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
||||
LOG.info("Loading llama-3 tokenizer with assistant dataset")
|
||||
strategy = load(
|
||||
llama3_tokenizer,
|
||||
DictDefault(
|
||||
{
|
||||
"train_on_inputs": False,
|
||||
"sequence_len": 512,
|
||||
}
|
||||
),
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"field_messages": "messages",
|
||||
}
|
||||
),
|
||||
)
|
||||
res = strategy.wrap_dataset(assistant_dataset)
|
||||
input_ids = res[0]["input_ids"]
|
||||
# fmt: off
|
||||
expected_input_ids = [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user