Add flexible configuration options for chat_template dataset training (#1756)

* Add flexible configuration options for chat dataset training

- Introduce roles_to_train parameter to set training labels by role
- Add train_on_eos option to configure training on end-of-sequence tokens
- Implement per-message training configuration in dataset
- Allow fine-grained control over training specific portions of messages
- Add message_field_training and message_field_training_detail settings
- Implement mapping between dataset character offsets and tokenized prompt
- Enhance test suite to cover new functionality

* Fix missing field inits, things weren't working from yaml.

* Add flexible configuration options for chat dataset training

- Introduce roles_to_train parameter to set training labels by role
- Add train_on_eos option to configure training on end-of-sequence tokens
- Implement per-message training configuration in dataset
- Allow fine-grained control over training specific portions of messages
- Add message_field_training and message_field_training_detail settings
- Implement mapping between dataset character offsets and tokenized prompt
- Enhance test suite to cover new functionality

* Fix missing field inits, things weren't working from yaml.

* chore: lint

* Revert test repo back to NousResearch after opening PR to fix the tokenizer_config.json.

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Adam Brusselback
2024-07-28 21:48:57 -04:00
committed by GitHub
parent 94ba93259f
commit 55cc214c76
4 changed files with 1111 additions and 92 deletions

View File

@@ -6,14 +6,16 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates from axolotl.utils.chat_templates import chat_templates
# Configure the logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
class ChatTemplatePrompter(Prompter): class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates""" """Prompter for HF chat templates"""
def __init__( def __init__(
self, self,
@@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter):
max_length=2048, max_length=2048,
message_field_role: str = "from", message_field_role: str = "from",
message_field_content: str = "value", message_field_content: str = "value",
message_field_training: str = "train",
message_field_training_detail: str = "train_detail",
roles: Optional[Dict[str, List[str]]] = None, roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False, drop_system_message: bool = False,
): ):
@@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter):
} }
self.message_field_role = message_field_role self.message_field_role = message_field_role
self.message_field_content = message_field_content self.message_field_content = message_field_content
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.chat_template = chat_template self.chat_template = chat_template
self.max_length = max_length self.max_length = max_length
@@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter):
{ {
"role": self.roles[t[self.message_field_role]], "role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content], "content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
} }
for t in conversation for t in conversation
] ]
@@ -62,6 +69,108 @@ class ChatTemplatePrompter(Prompter):
chat_template=self.chat_template, chat_template=self.chat_template,
) )
def get_offsets_for_train_detail(
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
) -> List[int]:
tokenized_output = self.tokenizer(
text, return_offsets_mapping=True, add_special_tokens=False
)
tokens = tokenized_output.tokens()
token_offsets = tokenized_output["offset_mapping"]
LOG.debug(f"Tokenizing text: {text}")
LOG.debug(f"Tokens: {tokens}")
# Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
for i in range(len(token_offsets) - 1):
token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
# Ensure the last token's end offset is set correctly
token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
LOG.debug(f"Token offsets: {token_offsets}")
# Initialize all offsets as IGNORE_TOKEN_ID (not trained)
result = [IGNORE_TOKEN_ID] * len(token_offsets)
# Adjust train_details to align with token boundaries
adjusted_train_details = self.adjust_train_details(train_details, token_offsets)
for idx, (start, end) in enumerate(token_offsets):
for detail in adjusted_train_details:
# Check if the token is completely within the detail's range
if start >= detail["begin_offset"] and end <= detail["end_offset"]:
if detail["train"] or not mask_untrainable:
result[idx] = start
LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
else:
LOG.debug(
f"Token {idx} ({tokens[idx]}) marked as non-trainable"
)
elif start < detail["end_offset"] and end > detail["begin_offset"]:
# Token partially overlaps with detail, always mark as non-trainable
LOG.debug(
f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
)
LOG.debug(f"Final result: {result}")
return result
def adjust_train_details(
self, train_details: List[Dict], token_offsets: List[tuple]
) -> List[Dict]:
adjusted_details = []
for detail in train_details:
begin_offset = detail["begin_offset"]
end_offset = detail["end_offset"]
# Find the first token that starts after or at the begin_offset
begin_token = next(
(
i
for i, (t_start, t_end) in enumerate(token_offsets)
if t_start >= begin_offset
),
len(token_offsets),
)
if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
begin_token -= 1
# Find the last token that ends before or at the end_offset
end_token = next(
(
i
for i in range(len(token_offsets) - 1, -1, -1)
if token_offsets[i][1] <= end_offset
),
-1,
)
if (
end_token < len(token_offsets) - 1
and token_offsets[end_token + 1][0] < end_offset
):
end_token += 1
if begin_token <= end_token:
adjusted_begin = token_offsets[begin_token][0]
adjusted_end = token_offsets[end_token][1]
if adjusted_begin != begin_offset or adjusted_end != end_offset:
LOG.warning(
f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
)
adjusted_details.append(
{
"begin_offset": adjusted_begin,
"end_offset": adjusted_end,
"train": detail["train"],
}
)
else:
LOG.warning(
f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
)
return adjusted_details
class ChatTemplateStrategy(PromptTokenizingStrategy): class ChatTemplateStrategy(PromptTokenizingStrategy):
""" """
@@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
_messages = "conversations" _messages = "conversations"
def __init__(
self,
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos="last",
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
@property @property
def messages(self): def messages(self):
return self._messages return self._messages
@@ -79,62 +201,169 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._messages = messages self._messages = messages
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt) turns = prompt[self.messages]
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns) input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
if not self.train_on_inputs: last_eos_idx = -1
user_prompt_len = len(prompt_ids) for index, turn in enumerate(turns):
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] role = turn.get(self.prompter.message_field_role)
else: content = turn.get(self.prompter.message_field_content)
labels = input_ids train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get(self.prompter.message_field_training_detail)
tokenized_prompt = { LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
)
should_train = (
train_turn
if train_turn is not None
else bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
LOG.debug(f"Should train: {should_train}")
turn_start_idx, turn_end_idx = self.find_turn(
conversation_ids=input_ids, turn=index, turn_content=turn
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
token_offsets = self.prompter.get_offsets_for_train_detail(
content, train_detail
)
LOG.debug(f"Token offsets: {token_offsets}")
for i, offset in enumerate(token_offsets):
if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
input_ids
):
labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
LOG.debug(
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
)
else:
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle EOS token
eos_idx = self.find_eos_token(input_ids, turn_end_idx)
if eos_idx == turn_end_idx:
last_eos_idx = eos_idx
if self.train_on_eos == "all" or (
self.train_on_eos == "turn" and should_train
):
labels[eos_idx] = input_ids[eos_idx]
LOG.debug(f"EOS token set for training at index {eos_idx}")
else:
LOG.debug(
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
)
# Handle 'last' option for train_on_eos
if self.train_on_eos == "last" and last_eos_idx != -1:
labels[last_eos_idx] = input_ids[last_eos_idx]
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
LOG.debug(f"Final labels: {labels}")
return {
"input_ids": input_ids, "input_ids": input_ids,
"labels": labels, "labels": labels,
"attention_mask": [1] * len(input_ids), "attention_mask": [1] * len(input_ids),
} }
return tokenized_prompt def find_eos_token(self, input_ids, start_idx):
eos_token_id = self.tokenizer.eos_token_id
for i in range(start_idx, len(input_ids)):
if input_ids[i] == eos_token_id:
return i
return -1
def find_turn(self, conversation_ids, turn, turn_content):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (str): String containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
"""
content = turn_content.get(self.prompter.message_field_content, "")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
# Locate the starting index after the specified number of EOS tokens
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break
# Find the start index of the content within the conversation
start_idx = -1
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break
if start_idx != -1:
end_idx = start_idx + len(content_ids)
else:
end_idx = -1
return start_idx, end_idx
def get_conversation_thread(self, prompt): def get_conversation_thread(self, prompt):
return prompt[self.messages] return prompt[self.messages]
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = ( ds_cfg = ds_cfg or {}
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
) prompter_params = {
message_field_role = ( "tokenizer": tokenizer,
ds_cfg["message_field_role"] "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
if ds_cfg and "message_field_role" in ds_cfg "message_field_role": ds_cfg.get("message_field_role", "from"),
else "from" "message_field_content": ds_cfg.get("message_field_content", "value"),
) "message_field_training": ds_cfg.get("message_field_training", "training"),
message_field_content = ( "message_field_training_detail": ds_cfg.get(
ds_cfg["message_field_content"] "message_field_training_detail", "train_detail"
if ds_cfg and "message_field_content" in ds_cfg ),
else "value" "roles": ds_cfg.get("roles"),
) "drop_system_message": ds_cfg.get("drop_system_message", False),
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None }
drop_system_message = (
ds_cfg["drop_system_message"] strategy_params = {
if ds_cfg and "drop_system_message" in ds_cfg "train_on_inputs": cfg.train_on_inputs,
else False "sequence_len": cfg.sequence_len,
) "roles_to_train": ds_cfg.get("roles_to_train"),
"train_on_eos": ds_cfg.get("train_on_eos", "last"),
}
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
tokenizer,
chat_templates(chat_template),
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
drop_system_message=drop_system_message,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"] strategy.messages = ds_cfg["field_messages"]
return strategy return strategy

View File

@@ -116,6 +116,10 @@ class SFTDataset(BaseModel):
field_messages: Optional[str] = None field_messages: Optional[str] = None
message_field_role: Optional[str] = None message_field_role: Optional[str] = None
message_field_content: Optional[str] = None message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None drop_system_message: Optional[bool] = None

View File

@@ -2,6 +2,7 @@
tests for chat_template prompt strategy tests for chat_template prompt strategy
""" """
import logging
import unittest import unittest
import pytest import pytest
@@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy, ChatTemplateStrategy,
load, load,
) )
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@pytest.fixture(name="assistant_dataset") @pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset(): def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list( return Dataset.from_list(
[ [
{ {
"messages": [ "messages": [
{ {"role": "user", "content": "hello"},
"role": "user", {"role": "assistant", "content": "hello"},
"content": "hello", {"role": "user", "content": "goodbye"},
}, {"role": "assistant", "content": "goodbye"},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
{
"role": "assistant",
"content": "goodbye",
},
] ]
} }
] ]
@@ -53,22 +45,28 @@ def fixture_sharegpt_dataset():
[ [
{ {
"conversations": [ "conversations": [
{ {"from": "human", "value": "hello"},
"from": "human", {"from": "gpt", "value": "hello"},
"value": "hello", {"from": "human", "value": "goodbye"},
}, {"from": "gpt", "value": "goodbye"},
{ ]
"from": "gpt", }
"value": "hello", ]
}, )
{
"from": "human",
"value": "goodbye", @pytest.fixture(name="basic_dataset")
}, def fixture_basic_dataset():
{ # pylint: disable=duplicate-code
"from": "gpt", return Dataset.from_list(
"value": "goodbye", [
}, {
"conversations": [
{"from": "system", "value": "You are an AI assistant."},
{"from": "human", "value": "Hello"},
{"from": "assistant", "value": "Hi there!"},
{"from": "human", "value": "How are you?"},
{"from": "assistant", "value": "I'm doing well, thank you!"},
] ]
} }
] ]
@@ -77,19 +75,611 @@ def fixture_sharegpt_dataset():
@pytest.fixture(name="llama3_tokenizer") @pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer(): def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer return tokenizer
class TestChatTemplateConfigurations:
"""
Test class for various configurations of ChatTemplateStrategy.
"""
@staticmethod
def find_sublist(full_list, sub_list):
token_count = len(sub_list)
for index in range(len(full_list) - token_count + 1):
if full_list[index : index + token_count] == sub_list:
return index
return -1
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
# Check the behavior of human inputs
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_ids)
labeled = all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(input_ids)]
)
LOG.debug(
f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
)
LOG.debug("Full labels: %s", labels)
LOG.debug("Full input_ids: %s", input_ids)
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that only assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
# Verify that human inputs are not labeled
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_ids)
LOG.debug(
f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
)
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
assert all(
label == IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(input_ids)]
), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that only assistant responses are labeled
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["human", "assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that all responses are labeled (except for special tokens)
all_responses = [
"Hello",
"Hi there!",
"How are you?",
"I'm doing well, thank you!",
]
for response in all_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
LOG.debug(
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
)
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
train_on_eos="none", # Add this line
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
# Verify that no labels are set when roles_to_train is empty
LOG.debug("Full labels: %s", labels)
assert all(
label == IGNORE_TOKEN_ID for label in labels
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="all",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to be labeled"
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="turn",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in assistant_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
assert start_idx != -1, f"Could not find '{response}' in input_ids"
eos_idx = start_idx + len(response_ids)
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert eos_idx < len(
input_ids
), f"Could not find EOS token after '{response}'"
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token after assistant response '{response}' to be labeled"
# Check that EOS tokens after human inputs are not labeled
human_inputs = ["Hello", "How are you?"]
for input_text in human_inputs:
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, input_ids)
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
eos_idx = start_idx + len(input_ids)
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token after human input '{input_text}' to not be labeled"
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="last",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
last_eos_idx = eos_indices[-1]
# Check that only the last EOS token is labeled
for idx in eos_indices[:-1]:
assert (
labels[idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {idx} to not be labeled"
assert (
labels[last_eos_idx] != IGNORE_TOKEN_ID
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="none",
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = llama3_tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to not be labeled"
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with drop_system_message=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
# Check if system message is not present in input_ids
system_message = "You are an AI assistant."
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
assert (
self.find_sublist(input_ids, system_ids) == -1
), "Expected system message to be dropped"
def test_custom_roles(self, llama3_tokenizer):
LOG.info("Testing with custom roles mapping")
custom_roles = {
"user": ["human", "user"],
"assistant": ["ai", "assistant"],
"system": ["context"],
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["ai"],
)
# Create a new dataset with modified role names
modified_conversations = [
{"from": "context", "value": "You are an AI assistant."},
{"from": "human", "value": "Hello"},
{"from": "ai", "value": "Hi there!"},
{"from": "human", "value": "How are you?"},
{"from": "ai", "value": "I'm doing well, thank you!"},
]
modified_dataset = Dataset.from_dict(
{"conversations": [modified_conversations]}
)
res = strategy.tokenize_prompt(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Check if AI responses are labeled correctly
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
for response in ai_responses:
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, response_ids)
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
assert all(
label != IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(response_ids)]
), f"Expected labels for AI response '{response}' to be set"
# Check if human messages are not labeled
human_messages = ["Hello", "How are you?"]
for message in human_messages:
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
start_idx = self.find_sublist(input_ids, message_ids)
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
assert all(
label == IGNORE_TOKEN_ID
for label in labels[start_idx : start_idx + len(message_ids)]
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
def test_message_field_training(self, llama3_tokenizer):
LOG.info("Testing with message_field_training")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
)
# Create a new dataset with the train and train_detail fields
modified_conversation = [
{"from": "system", "value": "You are an AI assistant.", "train": False},
{"from": "human", "value": "Hello", "train": False},
{"from": "assistant", "value": "Hello", "train": True},
{"from": "human", "value": "How are you?", "train": True},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": False},
{"begin_offset": 9, "end_offset": 18, "train": True},
{"begin_offset": 19, "end_offset": 30, "train": False},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": False,
},
{"from": "assistant", "value": "Hi there!", "train": True},
]
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
res = strategy.tokenize_prompt(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Function to find all occurrences of a sublist
def find_all_sublists(full_list, sub_list):
indices = []
for index in range(len(full_list) - len(sub_list) + 1):
if full_list[index : index + len(sub_list)] == sub_list:
indices.append(index)
return indices
# Keep track of which occurrences we've processed
processed_occurrences = {}
# Check if messages are labeled correctly based on train or train_detail
for i, turn in enumerate(modified_conversation):
turn_tokens = llama3_tokenizer.encode(
turn["value"], add_special_tokens=False
)
occurrences = find_all_sublists(input_ids, turn_tokens)
turn_key = turn["value"]
if turn_key not in processed_occurrences:
processed_occurrences[turn_key] = 0
current_occurrence = processed_occurrences[turn_key]
if current_occurrence >= len(occurrences):
assert (
False
), f"Not enough occurrences found for message: {turn['value']}"
start_idx = occurrences[current_occurrence]
processed_occurrences[turn_key] += 1
end_idx = start_idx + len(turn_tokens)
LOG.debug(
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
)
if "train_detail" in turn:
# Get token offsets
tokenized_output = llama3_tokenizer(
turn["value"], return_offsets_mapping=True, add_special_tokens=False
)
token_offsets = tokenized_output["offset_mapping"]
# Adjust token offsets as done in the implementation
for i in range(len(token_offsets) - 1):
token_offsets[i] = (
token_offsets[i][0],
token_offsets[i + 1][0] - 1,
)
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
# Adjust train_details
adjusted_train_details = strategy.prompter.adjust_train_details(
turn["train_detail"], token_offsets
)
LOG.debug(f"Original train_details: {turn['train_detail']}")
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
# Handle train_detail
token_offsets = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
mask_untrainable=False,
)
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
mask_untrainable=True,
)
LOG.debug(f"Token offsets: {token_offsets_masked}")
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
for i, offset in enumerate(token_offsets_masked):
if offset != IGNORE_TOKEN_ID:
expected_labels[i] = turn_tokens[i]
actual_labels = labels[
start_idx : start_idx + len(token_offsets_masked)
]
assert (
actual_labels == expected_labels
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
for detail in adjusted_train_details:
# Find the token indices that correspond to the character offsets
detail_start = start_idx + next(
i
for i, offset in enumerate(token_offsets)
if offset >= detail["begin_offset"]
)
detail_end = start_idx + next(
(
i
for i, offset in enumerate(token_offsets)
if offset > detail["end_offset"]
),
len(token_offsets),
)
detail_text = turn["value"][
detail["begin_offset"] : detail["end_offset"] + 1
]
detail_labels = labels[detail_start:detail_end]
detail_input_ids = input_ids[detail_start:detail_end]
LOG.debug(
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
)
LOG.debug(f"Detail input_ids: {detail_input_ids}")
LOG.debug(f"Detail labels: {detail_labels}")
LOG.debug(
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
)
LOG.debug(
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
)
if detail["train"]:
assert all(
label != IGNORE_TOKEN_ID for label in detail_labels
), (
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
f"InputIDs: {detail_input_ids}, "
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in detail_labels
), (
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
f"InputIDs: {detail_input_ids}, "
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
)
else:
should_train = turn.get("train", False)
turn_labels = labels[start_idx:end_idx]
LOG.debug(f"Should train: {should_train}")
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
LOG.debug(f"Turn labels: {turn_labels}")
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
LOG.debug(
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
)
if should_train:
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected all labels for '{turn['value']}' to be set\n"
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
f"InputIDs: {input_ids[start_idx:end_idx]}, "
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
)
else:
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
f"InputIDs: {input_ids[start_idx:end_idx]}, "
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
)
LOG.debug(
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
f"start_idx: {start_idx}, end_idx: {end_idx}, "
f"labels: {labels[start_idx:end_idx]}"
)
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
class TestAssistantChatTemplateLlama3: class TestAssistantChatTemplateLlama3:
""" """
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
""" """
def test_llama3_load(self, llama3_tokenizer, assistant_dataset): def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code LOG.info("Loading llama-3 tokenizer with assistant dataset")
strategy = load( strategy = load(
llama3_tokenizer, llama3_tokenizer,
DictDefault( DictDefault(
@@ -115,21 +705,26 @@ class TestAssistantChatTemplateLlama3:
res = strategy.tokenize_prompt(assistant_dataset[0]) res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"] input_ids = res["input_ids"]
# fmt: off # fmt: off
assert input_ids == [ expected_input_ids = [
128000, # bos 128000, # bos
128006, 882, 128007, # user header 128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot 271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header 128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot 271, 15339, 128009, # assistant response eot
128006, 882, 128007, 128006, 882, 128007,
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
128006, 78191, 128007, 128006, 78191, 128007,
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3(self, llama3_tokenizer, assistant_dataset): def test_llama3(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code LOG.info("Testing llama-3 with assistant dataset")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, llama3_tokenizer,
@@ -142,15 +737,16 @@ class TestAssistantChatTemplateLlama3:
"system": ["system"], "system": ["system"],
}, },
), ),
llama3_tokenizer, tokenizer=llama3_tokenizer,
False, train_on_inputs=False,
512, sequence_len=512,
roles_to_train=["assistant"],
) )
strategy.messages = "messages" strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0]) res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"] input_ids = res["input_ids"]
# fmt: off # fmt: off
assert input_ids == [ expected_input_ids = [
128000, # bos 128000, # bos
128006, 882, 128007, # user header 128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot 271, 15339, 128009, # user prompt eot
@@ -162,6 +758,64 @@ class TestAssistantChatTemplateLlama3:
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset including training data")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
roles={
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "messages"
prompt_tokens = strategy.prompter.build_prompt(
assistant_dataset[0]["messages"], False
)
prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
LOG.debug(f"Generated prompt: {prompt}")
res = strategy.tokenize_prompt(assistant_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# fmt: off
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert labels == expected_labels, (
f"Labels mismatch:\n"
f"Expected: {expected_labels}\n"
f"Actual: {labels}\n"
f"Input IDs: {input_ids}\n"
)
class TestSharegptChatTemplateLlama3: class TestSharegptChatTemplateLlama3:
@@ -169,30 +823,160 @@ class TestSharegptChatTemplateLlama3:
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy. Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
""" """
def test_llama3(self, llama3_tokenizer, sharegpt_dataset): def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
# pylint: disable=duplicate-code LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
llama3_tokenizer, tokenizer=llama3_tokenizer,
False, train_on_inputs=False,
512, train_on_eos="none",
sequence_len=512,
roles_to_train=["gpt"],
) )
res = strategy.tokenize_prompt(sharegpt_dataset[0]) res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"] input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off # fmt: off
assert input_ids == [ expected_input_ids = [
128000, # bos 128000, # bos
128006, 882, 128007, # user header 128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot 271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header 128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot 271, 15339, 128009, # assistant response eot
128006, 882, 128007, 128006, 882, 128007,
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
128006, 78191, 128007, 128006, 78191, 128007,
271, 19045, 29474, 128009, 271, 19045, 29474, 128009,
] ]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["human"],
)
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
sequence_len=512,
roles_to_train=["system", "human"],
)
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 9125, 128007,
271, 2675, 527, 459, 15592, 18328, 13, 128009,
128006, 882, 128007, # user header
271, 9906, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 13347, 1070, 0, 128009, # assistant response eot
128006, 882, 128007,
271, 4438, 527, 499, 30, 128009,
128006, 78191, 128007,
271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
]
expected_labels = [
IGNORE_TOKEN_ID, # bos
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -192,6 +192,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"] input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off # fmt: off
# pylint: disable=duplicate-code
assert input_ids == [ assert input_ids == [
128000, # bos 128000, # bos
128006, 9125, 128007, # system header 128006, 9125, 128007, # system header
@@ -228,6 +229,7 @@ class TestSharegptLlama3:
input_ids = dataset_wrapper[0]["input_ids"] input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off # fmt: off
# pylint: disable=duplicate-code
assert input_ids == [ assert input_ids == [
128000, # bos 128000, # bos
128006, 9125, 128007, # system header 128006, 9125, 128007, # system header