fix: use apply_chat_template to find turn boundaries and allow tool_calling field (#2179) [skip ci]

* fix: use apply_chat_template to find turn boundaries and allow tool_calling field

* fix: keys to include in turn

* feat(doc): explicitly recommend setting train_on_eos and roles_to_train

* fix: eos not being masked for tool due to template padding

* chore: clear up docs

* fix: default messages format, train_on_eos: turn, and train on all assistant msg

* fix: properly warn if empty content

* feat: parametrize chat_template tests to test different tokenizers

* fix: set proper default for message key

* fix: update defaults to match load function

* fix: change defaults to use new

* feat: add tool_calling dataset

* feat: add tool_calling test

* fix: add handling of edge case of mistral tokenizer with only system prompt

* feat: refactor all tests to follow source code

* fix: remove unnecessary eos_token from phi35

* fix test for phi3.5 since eos was dropped from chat_template

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2024-12-18 04:42:21 +07:00
committed by GitHub
parent 339f3c67e2
commit 10cfecf02e
7 changed files with 924 additions and 352 deletions

View File

@@ -7,6 +7,8 @@ from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
@@ -59,7 +61,52 @@ def fixture_basic_dataset():
)
@pytest.fixture(name="llama3_tokenizer")
@pytest.fixture(name="toolcalling_dataset")
def fixture_toolcalling_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"messages": [
{
"role": "system",
"content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.",
},
{
"role": "user",
"content": "Hey, what's the temperature in Paris right now?",
},
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_current_temperature",
"arguments": {
"location": "Paris, France",
"unit": "celsius",
},
},
}
],
},
{
"role": "tool",
"name": "get_current_temperature",
"content": "22.0",
},
{
"role": "assistant",
"content": "The temperature in Paris is 22.0 degrees Celsius.",
},
]
}
]
)
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
def fixture_llama3_tokenizer():
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
@@ -77,7 +124,53 @@ def fixture_llama3_tokenizer():
return tokenizer
@pytest.fixture(name="phi35_tokenizer")
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
)
return tokenizer
@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
def fixture_phi35_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
return tokenizer
@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
def fixture_gemma2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not 
none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
@pytest.fixture(name="gemma2_tokenizer_chat_template_jinja")
def fixture_gemma2_chat_template_jinja_w_system() -> str:
return "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
@pytest.fixture(name="llama3_2_vision_chat_template_jinja")
def fixture_llama3_2_vision_with_hardcoded_date() -> str:
"""Hardcodes the date in the template to avoid the need for date logic in the prompt"""
template = _CHAT_TEMPLATES["llama3_2_vision"]
old_date_logic = """{%- if not date_string is defined %}
{%- if strftime_now is defined %}
{%- set date_string = strftime_now("%d %b %Y") %}
{%- else %}
{%- set date_string = "26 Jul 2024" %}
{%- endif %}
{%- endif %}"""
new_date_logic = """{%- set date_string = "17 Dec 2024" %}"""
modified_template = template.replace(old_date_logic, new_date_logic)
return modified_template

View File

@@ -140,7 +140,6 @@ class TestAssistantChatTemplateLlama3:
1781, 26966, 32007, # user eot
32001, # assistant
1781, 26966, 32007, # assistant eot
32000, # eos
]
expected_labels = [
-100, # user
@@ -151,7 +150,6 @@ class TestAssistantChatTemplateLlama3:
-100, -100, -100, # user eot
-100, # assistant
1781, 26966, 32007, # assistant eot
32000, # eos
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
@@ -230,7 +228,10 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -238,6 +239,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["gpt"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -283,7 +285,10 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -291,6 +296,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -336,7 +342,10 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -344,6 +353,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["system", "human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -389,5 +399,148 @@ class TestSharegptChatTemplateLlama3:
), f"Labels mismatch: {labels} != {expected_labels}"
class TestAssistantToolCallingChatTemplateLlama32Vision:
    """
    Test class for assistant style datasets with tool_calling prompts using the llama-32_vision chat template.
    """

    def test_llama32vision_train_on_assistant(
        self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
    ):
        """With roles_to_train=["assistant"], only assistant turns (including the
        tool_calls turn) are labeled; system/user/tool turns are fully masked."""
        LOG.info(
            "Testing assistant style datasets with tool_calling with llama-32 chat template, training on assistant"
        )
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template(
                    "jinja", jinja_template=llama3_2_vision_chat_template_jinja
                ),
                message_field_role="role",
                message_field_content="content",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            # "turn": the EOS/EOT token is trained only at the end of trained turns.
            train_on_eos="turn",
            sequence_len=512,
            roles_to_train=["assistant"],
        )
        res = strategy.tokenize_prompt(toolcalling_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 9125, 128007, 271, # system header
            38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
            2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
            128006, 882, 128007, 271, # user header
            19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
            128006, 78191, 128007, 271, # assistant header
            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
            128006, 23799, 4690, 128007, 271, # tool header
            1, 1313, 13, 15, 1, 128009, # tool message
            128006, 78191, 128007, 271, # assistant header
            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool message
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
        ]
        # fmt: on
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"

    def test_llama32vision_train_on_tools(
        self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
    ):
        """With roles_to_train=["assistant", "tool"], the tool turn's content and
        EOS are labeled, while the surrounding template padding stays masked."""
        LOG.info(
            "Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
        )
        # pylint: disable=duplicate-code
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template(
                    "jinja", jinja_template=llama3_2_vision_chat_template_jinja
                ),
                message_field_role="role",
                message_field_content="content",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="turn",
            sequence_len=512,
            roles_to_train=["assistant", "tool"],
        )
        res = strategy.tokenize_prompt(toolcalling_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 9125, 128007, 271, # system header
            38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
            2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
            128006, 882, 128007, 271, # user header
            19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
            128006, 78191, 128007, 271, # assistant header
            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
            128006, 23799, 4690, 128007, 271, # tool header
            1, 1313, 13, 15, 1, 128009, # tool message
            128006, 78191, 128007, 271, # assistant header
            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
            IGNORE_TOKEN_ID, 1313, 13, 15, IGNORE_TOKEN_ID, 128009, # tool message (template's quote tokens masked; content + EOS trained)
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
        ]
        # fmt: on
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff