feat: add eos_tokens and train_on_eot for chat_template EOT parsing (#2364)
* feat: add eos_tokens and train_on_eot for chat_template EOT parsing * fix: comments * chore: add some examples of tokens * feat: add new potential errors for chat_template to faq * feat: add examples for EOT handling * fix: change error to warning for missing EOS * fix: warning typo * feat: add tests for eot token handling * fix: remove broken caplog capture in test * fix: chattemplate strategy with kd missing eot changes
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -53,14 +55,6 @@ class TestChatTemplateConfigurations:
|
||||
Test class for various configurations of ChatTemplateStrategy.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def find_sublist(full_list, sub_list):
|
||||
token_count = len(sub_list)
|
||||
for index in range(len(full_list) - token_count + 1):
|
||||
if full_list[index : index + token_count] == sub_list:
|
||||
return index
|
||||
return -1
|
||||
|
||||
@staticmethod
|
||||
def setup_tokenizer(
|
||||
tokenizer_name,
|
||||
@@ -68,6 +62,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template_jinja=None,
|
||||
eos_token=None,
|
||||
request=None,
|
||||
eot_token=None,
|
||||
) -> tuple[PreTrainedTokenizer, str]:
|
||||
"""
|
||||
Helper function to set up the tokenizer and chat template for the test.
|
||||
@@ -88,6 +83,10 @@ class TestChatTemplateConfigurations:
|
||||
"CodeLlamaTokenizerFast",
|
||||
):
|
||||
tokenizer.update_post_processor()
|
||||
|
||||
if eot_token:
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
|
||||
|
||||
return tokenizer, chat_template_jinja
|
||||
|
||||
def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):
|
||||
@@ -974,3 +973,311 @@ class TestChatTemplateConfigurations:
|
||||
raise ValueError(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
|
||||
def test_eot_tokens_conflict_with_eos_token(
    self,
    tokenizer,
    chat_template,
    chat_template_jinja,
    eos_token,
    basic_dataset,  # pylint: disable=unused-argument
    request,
):
    """
    Building the strategy must raise ValueError when eot_tokens includes the
    tokenizer's eos_token while train_on_eos and train_on_eot are set differently.
    """
    LOG.info(
        "Testing conflict between eot_tokens containing eos_token and train_on_eot/train_on_eos mismatch"
    )

    tokenizer, chat_template_jinja = self.setup_tokenizer(
        tokenizer, chat_template, chat_template_jinja, eos_token, request
    )

    # The EOS token is listed on purpose so that it overlaps with eot_tokens.
    overlapping_eot_tokens = [tokenizer.eos_token, "[/INST]"]

    strategy_kwargs = {
        "tokenizer": tokenizer,
        "train_on_inputs": False,
        "sequence_len": 512,
        "roles_to_train": ["assistant"],
        "train_on_eos": "none",
        "train_on_eot": "turn",  # deliberately differs from train_on_eos
        "eot_tokens": overlapping_eot_tokens,
    }
    expected_error = ".*eos_token is in eot_tokens and train_on_eos != train_on_eot.*"

    # Everything that touches the template/strategy stays inside the
    # raises-context, so any ValueError is checked against the pattern.
    with pytest.raises(ValueError, match=expected_error):
        resolved_template = get_chat_template(
            chat_template, jinja_template=chat_template_jinja
        )
        prompter = ChatTemplatePrompter(
            tokenizer,
            chat_template=resolved_template,
            message_property_mappings={"role": "from", "content": "value"},
            field_messages="conversations",
        )
        ChatTemplateStrategy(prompter, **strategy_kwargs)
|
||||
|
||||
def test_eot_token_backward_compatibility(
    self,
    tokenizer,
    chat_template,
    chat_template_jinja,
    eos_token,
    basic_dataset,  # pylint: disable=unused-argument
    request,
):
    """
    With neither eot_tokens nor train_on_eot passed, the strategy is expected
    to expose eot_tokens == [eos_token] and train_on_eot == train_on_eos.
    """
    LOG.info("Testing backward compatibility that eot_token inherits eos_token")

    tokenizer, chat_template_jinja = self.setup_tokenizer(
        tokenizer, chat_template, chat_template_jinja, eos_token, request
    )

    resolved_template = get_chat_template(
        chat_template, jinja_template=chat_template_jinja
    )
    prompter = ChatTemplatePrompter(
        tokenizer,
        chat_template=resolved_template,
        message_property_mappings={"role": "from", "content": "value"},
        field_messages="conversations",
    )
    # Only train_on_eos is configured; the EOT settings are left unset.
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
        train_on_eos="turn",
    )

    # The EOT configuration should mirror the EOS configuration.
    inherited_tokens = [tokenizer.eos_token]
    assert strategy.eot_tokens == inherited_tokens, (
        f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
    )
    assert strategy.train_on_eot == "turn", (
        f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
    )
|
||||
|
||||
def test_token_not_in_template(
    self,
    tokenizer,
    chat_template,
    chat_template_jinja,
    eos_token,
    basic_dataset,
    request,
):
    """
    Tokenizing a sample must complete without raising even when the
    configured EOT token does not occur in the chat template.
    """
    LOG.info("Testing runs even when tokens are not found in template")

    tokenizer, chat_template_jinja = self.setup_tokenizer(
        tokenizer, chat_template, chat_template_jinja, eos_token, request
    )

    # Register a special token whose text is chosen so that no chat
    # template under test contains it.
    missing_token = "[DEFINITELY_NOT_IN_TEMPLATE]"
    tokenizer.add_special_tokens({"additional_special_tokens": [missing_token]})

    resolved_template = get_chat_template(
        chat_template, jinja_template=chat_template_jinja
    )
    prompter = ChatTemplatePrompter(
        tokenizer,
        chat_template=resolved_template,
        message_property_mappings={"role": "from", "content": "value"},
        field_messages="conversations",
    )
    strategy = ChatTemplateStrategy(
        prompter,
        tokenizer=tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
        eot_tokens=[missing_token],
    )

    # tokenize_prompt is what triggers the template check; reaching the end
    # of this call without an exception is the whole assertion.
    first_sample = basic_dataset[0]
    strategy.tokenize_prompt(first_sample)

    # NOTE: a "not found in chat_template" warning is presumably logged here,
    # but it is intentionally not asserted: capturing it via caplog conflicts
    # with other tests when the suite runs together.
|
||||
|
||||
def test_custom_eot_tokens(
    self,
    tokenizer,
    chat_template,
    chat_template_jinja,
    eos_token,  # pylint: disable=unused-argument
    basic_dataset,
    request,
):
    """
    Check label masking of a custom EOT token with train_on_eot="turn".

    A dedicated "[EOT]" special token is appended after every assistant
    message by a custom template. Each occurrence that falls inside (or
    directly after) an assistant turn must carry a label; any other
    occurrence must be masked with IGNORE_TOKEN_ID.
    """
    LOG.info("Testing with custom EOT tokens")

    # eos_token is deliberately not forwarded (None) for this test.
    tokenizer, chat_template_jinja = self.setup_tokenizer(
        tokenizer, chat_template, chat_template_jinja, None, request
    )

    # Register the custom EOT marker as a special token.
    custom_eot = "[EOT]"
    tokenizer.add_special_tokens({"additional_special_tokens": [custom_eot]})

    # Template that emits the EOT marker after assistant messages only.
    custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content'] }}{% elif message['role'] == 'user' %}User: {{ message['content'] }}{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}[EOT]{% endif %}{% endfor %}"""

    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(
            tokenizer,
            chat_template=custom_template,
            message_property_mappings={"role": "from", "content": "value"},
            field_messages="conversations",
        ),
        tokenizer=tokenizer,
        train_on_inputs=False,
        sequence_len=512,
        roles_to_train=["assistant"],
        train_on_eot="turn",  # Train on EOT token after each turn
        eot_tokens=[custom_eot],
    )

    sample = basic_dataset[0]
    res = strategy.tokenize_prompt(sample)
    labels = res["labels"]
    input_ids = res["input_ids"]

    # Positions of every EOT token in the tokenized sample.
    eot_token_id = tokenizer.convert_tokens_to_ids(custom_eot)
    eot_indices = [
        position
        for position, token_id in enumerate(input_ids)
        if token_id == eot_token_id
    ]
    assert len(eot_indices) > 0, "Expected at least one EOT token in the input"

    # Collect the (start, end) span of every assistant turn that was located.
    # Spans of other roles are not needed: an EOT token outside all assistant
    # spans is by definition "not after an assistant turn".
    turns = strategy.get_conversation_thread(sample)
    assistant_turn_indices = []
    for turn_idx, turn in enumerate(sample["conversations"]):
        start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=turn_idx)
        turn_found = start_idx != -1 and end_idx != -1
        if turn_found and turn["from"] == "assistant":
            assistant_turn_indices.append((start_idx, end_idx))

    for eot_idx in eot_indices:
        is_after_assistant = any(
            start_idx <= eot_idx <= end_idx + 1  # +1 to include the EOT token
            for start_idx, end_idx in assistant_turn_indices
        )

        if is_after_assistant:
            assert labels[eot_idx] != IGNORE_TOKEN_ID, (
                f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
            )
        else:
            assert labels[eot_idx] == IGNORE_TOKEN_ID, (
                f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
            )
|
||||
|
||||
def test_multiple_train_on_eot_settings(
    self,
    tokenizer,
    chat_template,
    chat_template_jinja,
    eos_token,
    basic_dataset,
    request,
):
    """
    Check EOT label masking for every supported train_on_eot setting.

    The tokenizer's eos_token doubles as the EOT token. For each setting a
    fresh strategy tokenizes the first sample, and every EOS/EOT position is
    checked: it must carry a label exactly when the setting says it should
    be trained on, and be IGNORE_TOKEN_ID otherwise.
    """
    LOG.info("Testing different train_on_eot settings")

    tokenizer, chat_template_jinja = self.setup_tokenizer(
        tokenizer, chat_template, chat_template_jinja, eos_token, request
    )

    # Each predicate receives (is_assistant, is_last) and answers whether the
    # EOT token at that position should be trained on under the setting.
    should_train_by_setting = {
        "none": lambda is_assistant, is_last: False,  # Never train on EOT
        "all": lambda is_assistant, is_last: True,  # Always train on EOT
        "turn": lambda is_assistant, is_last: is_assistant,  # After assistant turns
        "last": lambda is_assistant, is_last: is_last,  # Only the last EOT
    }

    sample = basic_dataset[0]
    conversation = sample["conversations"]

    for setting, should_train in should_train_by_setting.items():
        LOG.info("Testing train_on_eot='%s'", setting)

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
            train_on_eot=setting,
            # Use eos_token as the EOT token for simplicity
            eot_tokens=[tokenizer.eos_token],
        )

        res = strategy.tokenize_prompt(sample)
        turns = strategy.get_conversation_thread(sample)
        labels = res["labels"]
        input_ids = res["input_ids"]

        eos_token_id = tokenizer.eos_token_id
        eos_indices = [
            position
            for position, token_id in enumerate(input_ids)
            if token_id == eos_token_id
        ]
        assert len(eos_indices) > 0, "Expected at least one EOS/EOT token in the input"

        # find_turn is given only the turn index, never the EOS position, so
        # the spans are resolved once per setting instead of once per token.
        turn_spans = [
            (turn, *strategy.find_turn(turns=turns, turn_idx=turn_idx))
            for turn_idx, turn in enumerate(conversation)
        ]

        last_position = len(eos_indices) - 1
        for position, eos_idx in enumerate(eos_indices):
            # First located turn whose span (plus one trailing token, to
            # include the EOT itself) contains this EOS/EOT position.
            preceding_turn = next(
                (
                    turn
                    for turn, start_idx, end_idx in turn_spans
                    if start_idx != -1
                    and end_idx != -1
                    and start_idx <= eos_idx <= end_idx + 1
                ),
                None,
            )
            is_assistant = (
                preceding_turn is not None and preceding_turn["from"] == "assistant"
            )
            is_last = position == last_position

            if should_train(is_assistant, is_last):
                assert labels[eos_idx] != IGNORE_TOKEN_ID, (
                    f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
                )
            else:
                assert labels[eos_idx] == IGNORE_TOKEN_ID, (
                    f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
                )
|
||||
|
||||
Reference in New Issue
Block a user