Add flexible configuration options for chat_template dataset training (#1756)

* Add flexible configuration options for chat dataset training

- Introduce roles_to_train parameter to set training labels by role
- Add train_on_eos option to configure training on end-of-sequence tokens
- Implement per-message training configuration in dataset
- Allow fine-grained control over training specific portions of messages
- Add message_field_training and message_field_training_detail settings
- Implement mapping between dataset character offsets and tokenized prompt
- Enhance test suite to cover new functionality

* Fix missing field inits, things weren't working from yaml.

* Add flexible configuration options for chat dataset training

- Introduce roles_to_train parameter to set training labels by role
- Add train_on_eos option to configure training on end-of-sequence tokens
- Implement per-message training configuration in dataset
- Allow fine-grained control over training specific portions of messages
- Add message_field_training and message_field_training_detail settings
- Implement mapping between dataset character offsets and tokenized prompt
- Enhance test suite to cover new functionality

* Fix missing field inits, things weren't working from yaml.

* chore: lint

* Revert test repo back to NousResearch after opening PR to fix the tokenizer_config.json.

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Adam Brusselback
2024-07-28 21:48:57 -04:00
committed by GitHub
parent 94ba93259f
commit 55cc214c76
4 changed files with 1111 additions and 92 deletions

View File

@@ -2,6 +2,7 @@
tests for chat_template prompt strategy
"""
import logging
import unittest
import pytest
@@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy,
load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
{
"role": "assistant",
"content": "goodbye",
},
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hello"},
{"role": "user", "content": "goodbye"},
{"role": "assistant", "content": "goodbye"},
]
}
]
@@ -53,22 +45,28 @@ def fixture_sharegpt_dataset():
[
{
"conversations": [
{
"from": "human",
"value": "hello",
},
{
"from": "gpt",
"value": "hello",
},
{
"from": "human",
"value": "goodbye",
},
{
"from": "gpt",
"value": "goodbye",
},
{"from": "human", "value": "hello"},
{"from": "gpt", "value": "hello"},
{"from": "human", "value": "goodbye"},
{"from": "gpt", "value": "goodbye"},
]
}
]
)
@pytest.fixture(name="basic_dataset")
def fixture_basic_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"conversations": [
{"from": "system", "value": "You are an AI assistant."},
{"from": "human", "value": "Hello"},
{"from": "assistant", "value": "Hi there!"},
{"from": "human", "value": "How are you?"},
{"from": "assistant", "value": "I'm doing well, thank you!"},
]
}
]
@@ -77,19 +75,611 @@ def fixture_sharegpt_dataset():
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer
class TestChatTemplateConfigurations:
"""
Test class for various configurations of ChatTemplateStrategy.
"""
@staticmethod
def find_sublist(full_list, sub_list):
token_count = len(sub_list)
for index in range(len(full_list) - token_count + 1):
if full_list[index : index + token_count] == sub_list:
return index
return -1
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
# Check the behavior of human inputs
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_ids)
labeled = all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(input_ids)]
)
LOG.debug(
f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
)
LOG.debug("Full labels: %s", labels)
LOG.debug("Full input_ids: %s", input_ids)
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that only assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
# Verify that human inputs are not labeled
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_ids)
LOG.debug(
f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
assert all(
label == IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(input_ids)]
), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that only assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["human", "assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that all responses are labeled (except for special tokens)
all_responses = [
"Hello",
"Hi there!",
"How are you?",
"I'm doing well, thank you!",
]
for response in all_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
train_on_eos="none", # Add this line
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
# Verify that no labels are set when roles_to_train is empty
LOG.debug("Full labels: %s", labels)
assert all(
label == IGNORE_TOKEN_ID for label in labels
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="all",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to be labeled"
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="turn",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
eos_idx = start_idx + len(response_ids)
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert eos_idx < len(
input_ids
), f"Could not find EOS token after '{response}'"
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token after assistant response '{response}' to be labeled"
# Check that EOS tokens after human inputs are not labeled
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_ids)
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
eos_idx = start_idx + len(input_ids)
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token after human input '{input_text}' to not be labeled"
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="last",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
last_eos_idx = eos_indices[-1]
# Check that only the last EOS token is labeled
for idx in eos_indices[:-1]:
assert (
labels[idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {idx} to not be labeled"
assert (
labels[last_eos_idx] != IGNORE_TOKEN_ID
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="none",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to not be labeled"
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with drop_system_message=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
# Check if system message is not present in input_ids
system_message = "You are an AI assistant."
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
assert (
self.find_sublist(input_ids, system_ids) == -1
), "Expected system message to be dropped"
def test_custom_roles(self, llama3_tokenizer):
LOG.info("Testing with custom roles mapping")
custom_roles = {
"user": ["human", "user"],
"assistant": ["ai", "assistant"],
"system": ["context"],
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["ai"],
)
# Create a new dataset with modified role names
modified_conversations = [
{"from": "context", "value": "You are an AI assistant."},
{"from": "human", "value": "Hello"},
{"from": "ai", "value": "Hi there!"},
{"from": "human", "value": "How are you?"},
{"from": "ai", "value": "I'm doing well, thank you!"},
]
modified_dataset = Dataset.from_dict(
{"conversations": [modified_conversations]}
)
res = strategy.tokenize_prompt(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Check if AI responses are labeled correctly
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in ai_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for AI response '{response}' to be set"
# Check if human messages are not labeled
human_messages = ["Hello", "How are you?"]
for message in human_messages:
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, message_ids)
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
assert all(
label == IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(message_ids)]
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
def test_message_field_training(self, llama3_tokenizer):
LOG.info("Testing with message_field_training")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
)
# Create a new dataset with the train and train_detail fields
modified_conversation = [
{"from": "system", "value": "You are an AI assistant.", "train": False},
{"from": "human", "value": "Hello", "train": False},
{"from": "assistant", "value": "Hello", "train": True},
{"from": "human", "value": "How are you?", "train": True},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": False},
{"begin_offset": 9, "end_offset": 18, "train": True},
{"begin_offset": 19, "end_offset": 30, "train": False},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": False,
},
{"from": "assistant", "value": "Hi there!", "train": True},
]
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
res = strategy.tokenize_prompt(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Function to find all occurrences of a sublist
def find_all_sublists(full_list, sub_list):
indices = []
for index in range(len(full_list) - len(sub_list) + 1):
if full_list[index : index + len(sub_list)] == sub_list:
indices.append(index)
return indices
# Keep track of which occurrences we've processed
processed_occurrences = {}
# Check if messages are labeled correctly based on train or train_detail
for i, turn in enumerate(modified_conversation):
turn_tokens = llama3_tokenizer.encode(
turn["value"], add_special_tokens=False
)
occurrences = find_all_sublists(input_ids, turn_tokens)
turn_key = turn["value"]
if turn_key not in processed_occurrences:
processed_occurrences[turn_key] = 0
current_occurrence = processed_occurrences[turn_key]
if current_occurrence >= len(occurrences):
assert (
False
), f"Not enough occurrences found for message: {turn['value']}"
start_idx = occurrences[current_occurrence]
processed_occurrences[turn_key] += 1
end_idx = start_idx + len(turn_tokens)
LOG.debug(
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
)
if "train_detail" in turn:
# Get token offsets
tokenized_output = llama3_tokenizer(
turn["value"], return_offsets_mapping=True, add_special_tokens=False
)
token_offsets = tokenized_output["offset_mapping"]
# Adjust token offsets as done in the implementation
for i in range(len(token_offsets) - 1):
token_offsets[i] = (
token_offsets[i][0],
token_offsets[i + 1][0] - 1,
)
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
# Adjust train_details
adjusted_train_details = strategy.prompter.adjust_train_details(
turn["train_detail"], token_offsets
)
LOG.debug(f"Original train_details: {turn['train_detail']}")
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
# Handle train_detail
token_offsets = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
mask_untrainable=False,
)
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
mask_untrainable=True,
)
LOG.debug(f"Token offsets: {token_offsets_masked}")
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
for i, offset in enumerate(token_offsets_masked):
if offset != IGNORE_TOKEN_ID:
expected_labels[i] = turn_tokens[i]
actual_labels = labels[
start_idx : start_idx + len(token_offsets_masked)
]
assert (
actual_labels == expected_labels
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
for detail in adjusted_train_details:
# Find the token indices that correspond to the character offsets
detail_start = start_idx + next(
i
for i, offset in enumerate(token_offsets)
if offset >= detail["begin_offset"]
)
detail_end = start_idx + next(
(
i
for i, offset in enumerate(token_offsets)
if offset > detail["end_offset"]
),
len(token_offsets),
)
detail_text = turn["value"][
detail["begin_offset"] : detail["end_offset"] + 1
]
detail_labels = labels[detail_start:detail_end]
detail_input_ids = input_ids[detail_start:detail_end]
LOG.debug(
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
)
LOG.debug(f"Detail input_ids: {detail_input_ids}")
LOG.debug(f"Detail labels: {detail_labels}")
LOG.debug(
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
)
LOG.debug(
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
)
if detail["train"]:
assert all(
label != IGNORE_TOKEN_ID for label in detail_labels
), (
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
f"InputIDs: {detail_input_ids}, "
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in detail_labels
), (
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
f"InputIDs: {detail_input_ids}, "
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
)
else:
should_train = turn.get("train", False)
turn_labels = labels[start_idx:end_idx]
LOG.debug(f"Should train: {should_train}")
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
LOG.debug(f"Turn labels: {turn_labels}")
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
LOG.debug(
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
)
if should_train:
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected all labels for '{turn['value']}' to be set\n"
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
f"InputIDs: {input_ids[start_idx:end_idx]}, "
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
)
else:
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
f"InputIDs: {input_ids[start_idx:end_idx]}, "
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
)
LOG.debug(
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
f"start_idx: {start_idx}, end_idx: {end_idx}, "
f"labels: {labels[start_idx:end_idx]}"
)
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
class TestAssistantChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
"""
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
LOG.info("Loading llama-3 tokenizer with assistant dataset")
strategy = load(
llama3_tokenizer,
DictDefault(
@@ -115,21 +705,26 @@ class TestAssistantChatTemplateLlama3:
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
assert input_ids == [
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
LOG.info("Testing llama-3 with assistant dataset")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -142,15 +737,16 @@ class TestAssistantChatTemplateLlama3:
"system": ["system"],
},
),
llama3_tokenizer,
False,
512,
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
assert input_ids == [
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
@@ -162,6 +758,64 @@ class TestAssistantChatTemplateLlama3:
271, 19045, 29474, 128009,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset including training data")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
roles={
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "messages"
prompt_tokens = strategy.prompter.build_prompt(
assistant_dataset[0]["messages"], False
)
prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
LOG.debug(f"Generated prompt: {prompt}")
res = strategy.tokenize_prompt(assistant_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# fmt: off
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert labels == expected_labels, (
f"Labels mismatch:\n"
f"Expected: {expected_labels}\n"
f"Actual: {labels}\n"
f"Input IDs: {input_ids}\n"
)
class TestSharegptChatTemplateLlama3:
@@ -169,30 +823,160 @@ class TestSharegptChatTemplateLlama3:
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
"""
def test_llama3(self, llama3_tokenizer, sharegpt_dataset):
# pylint: disable=duplicate-code
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
llama3_tokenizer,
False,
512,
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["gpt"],
)
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
assert input_ids == [
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["human"],
)
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["system", "human"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 9125, 128007,
271, 2675, 527, 459, 15592, 18328, 13, 128009,
128006, 882, 128007, # user header
271, 9906, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 13347, 1070, 0, 128009, # assistant response eot
128006, 882, 128007,
271, 4438, 527, 499, 30, 128009,
128006, 78191, 128007,
271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
if __name__ == "__main__":
unittest.main()

View File

@@ -192,6 +192,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
# pylint: disable=duplicate-code
assert input_ids == [
128000, # bos
128006, 9125, 128007, # system header
@@ -228,6 +229,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
# pylint: disable=duplicate-code
assert input_ids == [
128000, # bos
128006, 9125, 128007, # system header