fix: use apply_chat_template to find turn boundaries and allow tool_calling field (#2179) [skip ci]
* fix: use apply_chat_template to find turn boundaries and allow tool_calling field
* fix: keys to include in turn
* feat(doc): explicitly recommend setting train_on_eos and roles_to_train
* fix: eos not being masked for tool due to template padding
* chore: clear up docs
* fix: default messages format, train_on_eos: turn, and train on all assistant msg
* fix: properly warn if empty content
* feat: parametrize chat_template tests to test different tokenizers
* fix: set proper default for message key
* fix: update defaults to match load function
* fix: change defaults to use new
* feat: add tool_calling dataset
* feat: add tool_calling test
* fix: add handling of edge case of mistral tokenizer with only system prompt
* feat: refactor all tests to follow source code
* fix: remove unnecessary eos_token from phi35
* fix: test for phi3.5 since eos was dropped from chat_template

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -127,34 +127,40 @@ datasets:
 # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
 # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
 chat_template: tokenizer_default

-# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
+# Custom jinja chat template. Used only if `chat_template: jinja` or empty.
 chat_template_jinja:

-# The key in the data example that contains the messages. Default is "messages".
+# Key containing the messages (default: "messages")
 field_messages: messages

-# The key in the message turn that contains the role. Default is "role".
+# Key for role in each message (default: "role")
 message_field_role: role

-# The key in the message turn that contains the content. Default is "content".
+# Key for content in each message (default: "content")
 message_field_content: content

-# Optional[Dict[str, List]]. Roles mapping for the messages.
+# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
 roles:
   user: ["human", "user"]
-  assistant: ["gpt", "assistant", "ai"]
+  assistant: ["gpt", "assistant"]
   system: ["system"]
+  tool: ["tool"]

-## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
+# IMPORTANT: The following fields determine which parts of the conversation to train on.
+# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
+# See examples at `docs/dataset-formats/conversation.qmd`
+# Note: If the below 4 fields are empty, defaults to training only on the last message.

 # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
-roles_to_train: ["gpt", "assistant"]
+roles_to_train: ["assistant"] # default

 # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
 # - all: train on all EOS tokens
-# - turn: train on the EOS token at the end of each trainable turn
+# - turn (default): train on the EOS token at the end of each trainable turn
 # - last: train on the last EOS token in the conversation
 train_on_eos: last

 # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
 message_field_training: training

 # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
 # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
 # See example at `docs/dataset-formats/conversation.qmd`
 message_field_training_detail: train_detail
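The priority order documented above can be sketched as a small Python helper. This is illustrative only — `should_train_turn` is not axolotl's API, and the field names assume the defaults `training` and `train_detail` already resolved onto each turn:

```python
def should_train_turn(turn: dict, roles_to_train: list, train_on_inputs: bool) -> bool:
    """Decide whether a turn's tokens contribute to the loss, following the
    documented priority: per-turn boolean > per-token detail > global flags."""
    # 1. message_field_training: an explicit per-turn boolean wins outright
    if turn.get("training") is not None:
        return turn["training"]
    # 2. message_field_training_detail: per-token offsets; the turn participates
    #    if any detail entry asks for training
    detail = turn.get("training_detail")
    if detail is not None:
        return any(d.get("train") for d in detail)
    # 3. otherwise fall back to train_on_inputs or membership in roles_to_train
    return train_on_inputs or turn["role"] in roles_to_train
```

In the real strategy the detail entries additionally mask individual character ranges within the turn; here they only decide whether the turn participates at all.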
@@ -68,6 +68,8 @@ We recommend checking the below examples for other usecases.
 datasets:
   - path: ...
     type: chat_template
+    roles_to_train:
+    train_on_eos:
 ```

 2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.

@@ -77,7 +79,7 @@ chat_template: gemma # this overwrites the tokenizer's chat_template
 datasets:
   - path: ...
     type: chat_template
-    roles_to_train: ["assistant"]
+    roles_to_train: ["assistant"] # default value
 ```

 3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.

@@ -87,7 +89,6 @@ chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer
 datasets:
   - path: ...
     type: chat_template
-    roles_to_train: ["assistant"]
 ```

 4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.

@@ -99,7 +100,6 @@ chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message
 datasets:
   - path: ...
     type: chat_template
-    roles_to_train: ["assistant"]
 ```

 5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
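To make item 5 concrete, here is a hypothetical single dataset row using the per-turn `training` flag and per-token `train_detail` offsets from the config reference (field names follow the documented defaults; offsets are character indices into `content`):

```python
# A hypothetical example row, not taken from any real dataset.
example = {
    "messages": [
        {"role": "user", "content": "Hello", "training": False},
        {
            "role": "assistant",
            "content": "Hi there! How can I help?",
            # train only on the first 9 characters: "Hi there!"
            "train_detail": [
                {"begin_offset": 0, "end_offset": 9, "train": True},
            ],
        },
    ]
}
```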
@@ -25,8 +25,8 @@ class ChatTemplatePrompter(Prompter):
         processor=None,
         chat_template=None,
         max_length=2048,
-        message_field_role: str = "from",
-        message_field_content: str = "value",
+        message_field_role: str = "role",
+        message_field_content: str = "content",
         message_field_training: Optional[str] = None,
         message_field_training_detail: Optional[str] = None,
         roles: Optional[Dict[str, List[str]]] = None,
@@ -41,6 +41,7 @@ class ChatTemplatePrompter(Prompter):
             "assistant": "assistant",
             "gpt": "assistant",
             "system": "system",
+            "tool": "tool",
         }

         self.message_field_role = message_field_role
@@ -188,7 +189,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """

-    _messages = "conversations"
+    _messages = "messages"

     def __init__(
         self,
@@ -279,12 +280,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

         LOG.debug(f"Should train: {should_train}")

-        turn_start_idx, turn_end_idx = self.find_turn(
-            conversation_ids=input_ids, turn=index, turn_content=turn
-        )
-
-        if turn_start_idx == -1 or turn_end_idx == -1:
-            LOG.warning(f"Failed to find boundaries for turn {index}")
+        turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)

         LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
@@ -313,8 +309,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         LOG.debug(f"Labels after processing turn {index}: {labels}")

         # Handle EOS token
-        eos_idx = self.find_eos_token(input_ids, turn_end_idx)
-        if eos_idx == turn_end_idx:
+        eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
+        if abs(eos_idx - turn_end_idx) <= 3:  # Allow for some template padding
             last_eos_idx = eos_idx
             if self.train_on_eos == "all" or (
                 self.train_on_eos == "turn" and should_train
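The padding tolerance above can be sketched as standalone helpers (hypothetical names; the real logic is a method on `ChatTemplateStrategy` and reads `eos_token_id` from the tokenizer). The point of the `<= 3` check is that some templates emit a few filler tokens between the end of a turn's content and its EOS, so the first EOS found is accepted only if it sits close to the turn boundary:

```python
def find_first_eos_token(input_ids, start_idx, eos_token_id):
    """Scan forward from start_idx for the first EOS token; -1 if none."""
    for i in range(start_idx, len(input_ids)):
        if input_ids[i] == eos_token_id:
            return i
    return -1


def eos_for_turn(input_ids, turn_end_idx, eos_token_id, tolerance=3):
    """Return the EOS index belonging to this turn, allowing a few
    template-padding tokens between content end and EOS; -1 otherwise."""
    eos_idx = find_first_eos_token(input_ids, turn_end_idx, eos_token_id)
    if eos_idx != -1 and abs(eos_idx - turn_end_idx) <= tolerance:
        return eos_idx
    return -1
```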
@@ -339,75 +335,120 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             "attention_mask": [1] * len(input_ids),
         }

-    def find_eos_token(self, input_ids, start_idx):
+    def find_first_eos_token(self, input_ids, start_idx):
         eos_token_id = self.tokenizer.eos_token_id
         for i in range(start_idx, len(input_ids)):
             if input_ids[i] == eos_token_id:
                 return i
         return -1

-    def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
+    def find_turn(self, turns: list[dict], turn_idx: int):
         """
         Locate the starting and ending indices of the specified turn in a conversation.
         """
-        content = turn_content.get("content")
-        content_ids = self.tokenizer.encode(content, add_special_tokens=False)
-
-        LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
-
-        if not content_ids:
-            LOG.warning(f"Empty content for turn {turn}")
-            return -1, -1
-
-        # For first turn, start from beginning
-        if turn == 0:
-            start_search_idx = 0
-        else:
-            # For subsequent turns, find the previous EOS token
-            eos_token_id = self.tokenizer.eos_token_id
-            eos_count = 0
-            start_search_idx = 0
-
-            for i, token_id in enumerate(conversation_ids):
-                if token_id == eos_token_id:
-                    eos_count += 1
-                    if eos_count == turn:  # Find the nth EOS token where n = turn
-                        start_search_idx = i + 1
-                        break
-
-        # we can optimize this to only search for a few tokens from start_search_idx
-        # but it would risk missing the content if it's not found within the first few tokens or
-        # if start_search_idx cannot be found above.
-        last_index = len(conversation_ids) - len(content_ids) + 1
-
-        if last_index < start_search_idx:
-            LOG.warning(
-                f"last_index to search is less than start_search_idx for turn {turn}"
-            )
-            return -1, -1
-
-        # Search for content starting from start_search_idx
-        first_elem = content_ids[0]
-        for i in range(start_search_idx, last_index):
-            # Quick check of first element before doing full comparison
-            if conversation_ids[i] == first_elem:
-                # Check if the rest of the content matches
-                if conversation_ids[i : i + len(content_ids)] == content_ids:
-                    LOG.debug(f"Found turn {turn} content at position {i}")
-                    return i, i + len(content_ids)
-
-        return -1, -1
+        # pylint: disable=too-many-return-statements
+
+        if turn_idx >= len(turns):
+            raise ValueError(f"Turn index {turn_idx} out of range")
+
+        # mistral does not output message if it contains only system message
+        if (
+            turn_idx == 0
+            and turns[0].get("role") == "system"
+            and "mistral" in self.tokenizer.name_or_path.lower()
+        ):
+            return -1, -1
+
+        empty_turn = {
+            "role": turns[turn_idx].get("role"),
+            "content": "[[dummy_message]]",
+        }
+
+        # Create conversation versions
+        turns_with_empty = turns[:turn_idx] + [empty_turn]
+        turns_with_content = turns[: turn_idx + 1]
+
+        # Generate the conversation up to the turn, with final turn replaced with dummy content
+        dummy_ids = self.prompter.build_prompt(turns_with_empty)  # type: ignore
+
+        # Generate the conversation up to the turn, with final turn included
+        full_ids = self.prompter.build_prompt(turns_with_content)  # type: ignore
+
+        if not full_ids or not dummy_ids:
+            LOG.warning(f"Empty template generated for turn {turn_idx}")
+            return -1, -1
+
+        # Find first difference (start of content)
+        start_idx = None
+        min_len = min(len(dummy_ids), len(full_ids))
+        for i in range(min_len):
+            if dummy_ids[i] != full_ids[i]:
+                start_idx = i
+                break
+
+        if start_idx is None:
+            LOG.warning(f"Could not find content start boundary for turn {turn_idx}")
+            return -1, -1
+
+        # Find last difference (end of content)
+        end_idx = None
+        for i in range(min_len):
+            dummy_pos = len(dummy_ids) - 1 - i
+            full_pos = len(full_ids) - 1 - i
+            if dummy_ids[dummy_pos] != full_ids[full_pos]:
+                end_idx = full_pos + 1  # Add one to include the last token when slice
+                break
+
+        if end_idx is None:
+            LOG.warning(f"Could not find content end boundary for turn {turn_idx}")
+            return -1, -1
+
+        if end_idx < start_idx:
+            LOG.warning(
+                f"Content end boundary is before start boundary for turn {turn_idx}"
+            )
+            return -1, -1
+
+        if end_idx == start_idx:
+            LOG.warning(
+                f"Content end boundary is the same as start boundary for turn {turn_idx}. This is likely an empty turn."
+            )
+            return -1, -1
+
+        LOG.debug(f"Content boundaries: {start_idx}, {end_idx}")
+        LOG.debug(
+            f"Content tokens: {self.tokenizer.convert_ids_to_tokens(full_ids[start_idx:end_idx])}"
+        )
+
+        return start_idx, end_idx

     def get_conversation_thread(self, prompt):
-        turns = [
-            {
-                "role": self.prompter.roles[t[self.prompter.message_field_role]],
-                "content": t[self.prompter.message_field_content],
-                "training": t.get(self.prompter.message_field_training),
-                "training_detail": t.get(self.prompter.message_field_training_detail),
-            }
-            for t in prompt[self.messages]
-        ]
+        turns = []
+        optional_keys = [
+            "tool_calls",  # tool that 'assistant' calls
+            "name",  # name of tool given by 'tool'
+            "tool_call_id",  # mistral/mixtral requires this
+        ]
+        for message in prompt[self.messages]:
+            turn = {
+                "role": self.prompter.roles[message[self.prompter.message_field_role]],
+                "training": message.get(self.prompter.message_field_training),
+                "training_detail": message.get(
+                    self.prompter.message_field_training_detail
+                ),
+            }
+
+            # do not add content if None as it may conflict with some templates due to tools
+            content = message.get(self.prompter.message_field_content, None)
+            if content is not None:
+                turn["content"] = content
+
+            for key in optional_keys:
+                value = message.get(key, None)
+                if value is not None:
+                    turn[key] = value
+
+            turns.append(turn)

         if self.prompter.drop_system_message and turns[0]["role"] == "system":
             turns = turns[1:]
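The new `find_turn` above no longer searches for re-encoded content directly; instead it renders the conversation twice — once with the target turn replaced by a dummy message, once with the real content — and diffs the two token sequences from both ends. The shared prefix is template boilerplate, the shared suffix is trailing template tokens, and what differs in between is the turn's content. The core diff can be sketched as a pure function over token-id lists (a simplification, assuming the template renders deterministically):

```python
def find_content_span(dummy_ids: list, full_ids: list):
    """Return (start, end) of the content slice in full_ids, where dummy_ids is
    the same conversation rendered with a dummy final turn; (-1, -1) on failure."""
    if not dummy_ids or not full_ids:
        return -1, -1
    min_len = min(len(dummy_ids), len(full_ids))
    # first difference from the front = start of content
    start_idx = next((i for i in range(min_len) if dummy_ids[i] != full_ids[i]), None)
    if start_idx is None:
        return -1, -1
    # first difference from the back = end of content (exclusive slice end)
    end_idx = None
    for i in range(min_len):
        if dummy_ids[len(dummy_ids) - 1 - i] != full_ids[len(full_ids) - 1 - i]:
            end_idx = len(full_ids) - i
            break
    if end_idx is None or end_idx <= start_idx:
        return -1, -1
    return start_idx, end_idx
```

For example, with template prefix `[1, 2]` and suffix `[3]`, real content `[7, 8, 7]` against dummy content `[9, 9]`, the differing middle is exactly the content span.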
@@ -446,8 +487,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
     strategy_params = {
         "train_on_inputs": cfg.train_on_inputs,
         "sequence_len": cfg.sequence_len,
-        "roles_to_train": ds_cfg.get("roles_to_train", []),
-        "train_on_eos": ds_cfg.get("train_on_eos", None),
+        "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
+        "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
    }

     strategy = ChatTemplateStrategy(
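A minimal sketch of the changed defaults, with a plain dict standing in for the per-dataset config object (`resolve_strategy_params` is an illustrative name, not the loader's API): unset keys now fall back to `["assistant"]` and `"turn"` rather than an empty list and `None`.

```python
def resolve_strategy_params(ds_cfg: dict) -> dict:
    # mirrors the ds_cfg.get(...) fallbacks shown in the diff above
    return {
        "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
        "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
    }
```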
File diff suppressed because one or more lines are too long
@@ -7,6 +7,8 @@ from datasets import Dataset
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer

+from axolotl.utils.chat_templates import _CHAT_TEMPLATES
+

 @pytest.fixture(name="assistant_dataset")
 def fixture_assistant_dataset():
@@ -59,7 +61,52 @@ def fixture_basic_dataset():
     )


-@pytest.fixture(name="llama3_tokenizer")
+@pytest.fixture(name="toolcalling_dataset")
+def fixture_toolcalling_dataset():
+    # pylint: disable=duplicate-code
+    return Dataset.from_list(
+        [
+            {
+                "messages": [
+                    {
+                        "role": "system",
+                        "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.",
+                    },
+                    {
+                        "role": "user",
+                        "content": "Hey, what's the temperature in Paris right now?",
+                    },
+                    {
+                        "role": "assistant",
+                        "tool_calls": [
+                            {
+                                "type": "function",
+                                "function": {
+                                    "name": "get_current_temperature",
+                                    "arguments": {
+                                        "location": "Paris, France",
+                                        "unit": "celsius",
+                                    },
+                                },
+                            }
+                        ],
+                    },
+                    {
+                        "role": "tool",
+                        "name": "get_current_temperature",
+                        "content": "22.0",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "The temperature in Paris is 22.0 degrees Celsius.",
+                    },
+                ]
+            }
+        ]
+    )
+
+
+@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
 def fixture_llama3_tokenizer():
     hf_hub_download(
         repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
@@ -77,7 +124,53 @@ def fixture_llama3_tokenizer():
     return tokenizer


-@pytest.fixture(name="phi35_tokenizer")
+@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
+def fixture_mistralv03_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(
+        "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
+    )
+    return tokenizer
+
+
+@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
 def fixture_phi35_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
     return tokenizer


+@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
+def fixture_gemma2_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
+
+    return tokenizer
+
+
+@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
+def fixture_mistralv03_chat_template_jinja_w_system() -> str:
+    return '{%- if messages[0]["role"] == "system" %}\n    {%- set system_message = messages[0]["content"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n            {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message["role"] == "user" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- "[AVAILABLE_TOOLS] [" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- \'{"type": "function", "function": {\' }}\n                {%- for key, val in tool.items() if key != "return" %}\n                    {%- if val is string %}\n                        {{- \'"\' + key + \'": "\' + val + \'"\' }}\n                    {%- else %}\n                        {{- \'"\' + key + \'": \' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- ", " }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- "}}" }}\n                {%- if not loop.last %}\n                    {{- ", " }}\n                {%- else %}\n                    {{- "]" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- "[/AVAILABLE_TOOLS]" }}\n        {%- endif %}\n        {%- if loop.first and system_message is defined %}\n            {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n        {%- else %}\n            {{- "[INST] " + message["content"] + "[/INST]" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- "[TOOL_CALLS] [" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n            {%- endif %}\n            {{- \', "id": "\' + tool_call.id + \'"}\' }}\n            {%- if not loop.last %}\n                {{- ", " }}\n            {%- else %}\n                {{- "]" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message["role"] == "assistant" %}\n        {{- " " + message["content"]|trim + eos_token}}\n    {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n        {%- endif %}\n        {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n    {%- else %}\n        {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n    {%- endif %}\n{%- endfor %}\n'
+@pytest.fixture(name="gemma2_tokenizer_chat_template_jinja")
+def fixture_gemma2_chat_template_jinja_w_system() -> str:
+    return "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
+
+
+@pytest.fixture(name="llama3_2_vision_chat_template_jinja")
+def fixture_llama3_2_vision_with_hardcoded_date() -> str:
+    """Hardcodes the date in the template to avoid the need for date logic in the prompt"""
+
+    template = _CHAT_TEMPLATES["llama3_2_vision"]
+
+    old_date_logic = """{%- if not date_string is defined %}
+    {%- if strftime_now is defined %}
+        {%- set date_string = strftime_now("%d %b %Y") %}
+    {%- else %}
+        {%- set date_string = "26 Jul 2024" %}
+    {%- endif %}
+{%- endif %}"""
+
+    new_date_logic = """{%- set date_string = "17 Dec 2024" %}"""
+
+    modified_template = template.replace(old_date_logic, new_date_logic)
+
+    return modified_template
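The date-hardcoding fixture above is plain string patching: replace the template's date-resolution block with a literal so tokenization tests are deterministic. A tiny illustration with a stand-in template (not the real `llama3_2_vision` template):

```python
# Hypothetical template fragment; only the str.replace technique is the point.
TEMPLATE = "header {% set date_string = now() %} footer"

patched = TEMPLATE.replace(
    "{% set date_string = now() %}",
    '{% set date_string = "17 Dec 2024" %}',
)
```

If the old fragment ever drifts out of sync with the template, `replace` silently does nothing, so a test should assert the patch actually took effect.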
@@ -140,7 +140,6 @@ class TestAssistantChatTemplateLlama3:
             1781, 26966, 32007, # user eot
             32001, # assistant
             1781, 26966, 32007, # assistant eot
-            32000, # eos
         ]
         expected_labels = [
             -100, # user
@@ -151,7 +150,6 @@ class TestAssistantChatTemplateLlama3:
             -100, -100, -100, # user eot
             -100, # assistant
             1781, 26966, 32007, # assistant eot
-            32000, # eos
         ]
         # fmt: on
         LOG.debug(f"Expected input_ids: {expected_input_ids}")
@@ -230,7 +228,10 @@ class TestSharegptChatTemplateLlama3:
         # pylint: disable=duplicate-code
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=get_chat_template("llama3")
+                llama3_tokenizer,
+                chat_template=get_chat_template("llama3"),
+                message_field_role="from",
+                message_field_content="value",
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -238,6 +239,7 @@ class TestSharegptChatTemplateLlama3:
             sequence_len=512,
             roles_to_train=["gpt"],
         )
+        strategy.messages = "conversations"
         res = strategy.tokenize_prompt(sharegpt_dataset[0])
         input_ids = res["input_ids"]
         labels = res["labels"]
@@ -283,7 +285,10 @@ class TestSharegptChatTemplateLlama3:
         # pylint: disable=duplicate-code
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=get_chat_template("llama3")
+                llama3_tokenizer,
+                chat_template=get_chat_template("llama3"),
+                message_field_role="from",
+                message_field_content="value",
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -291,6 +296,7 @@ class TestSharegptChatTemplateLlama3:
             sequence_len=512,
             roles_to_train=["human"],
         )
+        strategy.messages = "conversations"
         res = strategy.tokenize_prompt(sharegpt_dataset[0])
         input_ids = res["input_ids"]
         labels = res["labels"]
@@ -336,7 +342,10 @@ class TestSharegptChatTemplateLlama3:
         # pylint: disable=duplicate-code
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=get_chat_template("llama3")
+                llama3_tokenizer,
+                chat_template=get_chat_template("llama3"),
+                message_field_role="from",
+                message_field_content="value",
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -344,6 +353,7 @@ class TestSharegptChatTemplateLlama3:
             sequence_len=512,
             roles_to_train=["system", "human"],
         )
+        strategy.messages = "conversations"
         res = strategy.tokenize_prompt(basic_dataset[0])
         input_ids = res["input_ids"]
         labels = res["labels"]
@@ -389,5 +399,148 @@ class TestSharegptChatTemplateLlama3:
         ), f"Labels mismatch: {labels} != {expected_labels}"


+class TestAssistantToolCallingChatTemplateLlama32Vision:
+    """
+    Test class for assistant style datasets with tool_calling prompts using the llama-32_vision chat template.
+    """
+
+    def test_llama32vision_train_on_assistant(
+        self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
+    ):
+        LOG.info(
+            "Testing assistant style datasets with tool_calling with llama-32 chat template, training on assistant"
+        )
+
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                llama3_tokenizer,
+                chat_template=get_chat_template(
+                    "jinja", jinja_template=llama3_2_vision_chat_template_jinja
+                ),
+                message_field_role="role",
+                message_field_content="content",
+            ),
+            tokenizer=llama3_tokenizer,
+            train_on_inputs=False,
+            train_on_eos="turn",
+            sequence_len=512,
+            roles_to_train=["assistant"],
+        )
+
+        res = strategy.tokenize_prompt(toolcalling_dataset[0])
+
+        input_ids = res["input_ids"]
+        labels = res["labels"]
+
+        # fmt: off
+        expected_input_ids = [
+            128000, # bos
+            128006, 9125, 128007, 271, # system header
+            38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
+            2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
+            128006, 882, 128007, 271, # user header
+            19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
+            128006, 78191, 128007, 271, # assistant header
+            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+            128006, 23799, 4690, 128007, 271, # tool header
+            1, 1313, 13, 15, 1, 128009, # tool message
+            128006, 78191, 128007, 271, # assistant header
+            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+        ]
+
+        expected_labels = [
+            IGNORE_TOKEN_ID, # bos
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+        ]
+        # fmt: on
+
+        assert (
+            input_ids == expected_input_ids
+        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+        assert (
+            labels == expected_labels
+        ), f"Labels mismatch: {labels} != {expected_labels}"
+
+    def test_llama32vision_train_on_tools(
+        self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
+    ):
+        LOG.info(
+            "Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
+        )
+        # pylint: disable=duplicate-code
+
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                llama3_tokenizer,
+                chat_template=get_chat_template(
+                    "jinja", jinja_template=llama3_2_vision_chat_template_jinja
+                ),
+                message_field_role="role",
+                message_field_content="content",
+            ),
+            tokenizer=llama3_tokenizer,
+            train_on_inputs=False,
+            train_on_eos="turn",
+            sequence_len=512,
+            roles_to_train=["assistant", "tool"],
+        )
+
+        res = strategy.tokenize_prompt(toolcalling_dataset[0])
+
+        input_ids = res["input_ids"]
+        labels = res["labels"]
+
+        # fmt: off
+        expected_input_ids = [
+            128000, # bos
+            128006, 9125, 128007, 271, # system header
+            38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
+            2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
+            128006, 882, 128007, 271, # user header
+            19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
+            128006, 78191, 128007, 271, # assistant header
+            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+            128006, 23799, 4690, 128007, 271, # tool header
+            1, 1313, 13, 15, 1, 128009, # tool message
+            128006, 78191, 128007, 271, # assistant header
+            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+        ]
+
+        expected_labels = [
+            IGNORE_TOKEN_ID, # bos
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+            5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
+            IGNORE_TOKEN_ID, 1313, 13, 15, IGNORE_TOKEN_ID, 128009, # tool message
+            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+            791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+        ]
+        # fmt: on
+
+        assert (
+            input_ids == expected_input_ids
+        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+        assert (
+            labels == expected_labels
+        ), f"Labels mismatch: {labels} != {expected_labels}"
+
+
 if __name__ == "__main__":
     unittest.main()
File diff suppressed because it is too large