Add flexible configuration options for chat_template dataset training (#1756)
* Add flexible configuration options for chat dataset training - Introduce roles_to_train parameter to set training labels by role - Add train_on_eos option to configure training on end-of-sequence tokens - Implement per-message training configuration in dataset - Allow fine-grained control over training specific portions of messages - Add message_field_training and message_field_training_detail settings - Implement mapping between dataset character offsets and tokenized prompt - Enhance test suite to cover new functionality * Fix missing field inits, things weren't working from yaml. * Add flexible configuration options for chat dataset training - Introduce roles_to_train parameter to set training labels by role - Add train_on_eos option to configure training on end-of-sequence tokens - Implement per-message training configuration in dataset - Allow fine-grained control over training specific portions of messages - Add message_field_training and message_field_training_detail settings - Implement mapping between dataset character offsets and tokenized prompt - Enhance test suite to cover new functionality * Fix missing field inits, things weren't working from yaml. * chore: lint * Revert test repo back to NousResearch after opening PR to fix the tokenizer_config.json. --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -6,14 +6,16 @@ import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
# Configure the logger
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class ChatTemplatePrompter(Prompter):
|
||||
"""prompter for HF chat templates"""
|
||||
"""Prompter for HF chat templates"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter):
|
||||
max_length=2048,
|
||||
message_field_role: str = "from",
|
||||
message_field_content: str = "value",
|
||||
message_field_training: str = "train",
|
||||
message_field_training_detail: str = "train_detail",
|
||||
roles: Optional[Dict[str, List[str]]] = None,
|
||||
drop_system_message: bool = False,
|
||||
):
|
||||
@@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter):
|
||||
}
|
||||
self.message_field_role = message_field_role
|
||||
self.message_field_content = message_field_content
|
||||
self.message_field_training = message_field_training
|
||||
self.message_field_training_detail = message_field_training_detail
|
||||
self.tokenizer = tokenizer
|
||||
self.chat_template = chat_template
|
||||
self.max_length = max_length
|
||||
@@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
{
|
||||
"role": self.roles[t[self.message_field_role]],
|
||||
"content": t[self.message_field_content],
|
||||
"training": t.get(self.message_field_training, None),
|
||||
}
|
||||
for t in conversation
|
||||
]
|
||||
@@ -62,6 +69,108 @@ class ChatTemplatePrompter(Prompter):
|
||||
chat_template=self.chat_template,
|
||||
)
|
||||
|
||||
def get_offsets_for_train_detail(
|
||||
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
|
||||
) -> List[int]:
|
||||
tokenized_output = self.tokenizer(
|
||||
text, return_offsets_mapping=True, add_special_tokens=False
|
||||
)
|
||||
tokens = tokenized_output.tokens()
|
||||
token_offsets = tokenized_output["offset_mapping"]
|
||||
|
||||
LOG.debug(f"Tokenizing text: {text}")
|
||||
LOG.debug(f"Tokens: {tokens}")
|
||||
# Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
|
||||
for i in range(len(token_offsets) - 1):
|
||||
token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
|
||||
# Ensure the last token's end offset is set correctly
|
||||
token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
|
||||
LOG.debug(f"Token offsets: {token_offsets}")
|
||||
|
||||
# Initialize all offsets as IGNORE_TOKEN_ID (not trained)
|
||||
result = [IGNORE_TOKEN_ID] * len(token_offsets)
|
||||
|
||||
# Adjust train_details to align with token boundaries
|
||||
adjusted_train_details = self.adjust_train_details(train_details, token_offsets)
|
||||
|
||||
for idx, (start, end) in enumerate(token_offsets):
|
||||
for detail in adjusted_train_details:
|
||||
# Check if the token is completely within the detail's range
|
||||
if start >= detail["begin_offset"] and end <= detail["end_offset"]:
|
||||
if detail["train"] or not mask_untrainable:
|
||||
result[idx] = start
|
||||
LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
|
||||
else:
|
||||
LOG.debug(
|
||||
f"Token {idx} ({tokens[idx]}) marked as non-trainable"
|
||||
)
|
||||
elif start < detail["end_offset"] and end > detail["begin_offset"]:
|
||||
# Token partially overlaps with detail, always mark as non-trainable
|
||||
LOG.debug(
|
||||
f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
|
||||
)
|
||||
|
||||
LOG.debug(f"Final result: {result}")
|
||||
return result
|
||||
|
||||
def adjust_train_details(
|
||||
self, train_details: List[Dict], token_offsets: List[tuple]
|
||||
) -> List[Dict]:
|
||||
adjusted_details = []
|
||||
for detail in train_details:
|
||||
begin_offset = detail["begin_offset"]
|
||||
end_offset = detail["end_offset"]
|
||||
|
||||
# Find the first token that starts after or at the begin_offset
|
||||
begin_token = next(
|
||||
(
|
||||
i
|
||||
for i, (t_start, t_end) in enumerate(token_offsets)
|
||||
if t_start >= begin_offset
|
||||
),
|
||||
len(token_offsets),
|
||||
)
|
||||
if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
|
||||
begin_token -= 1
|
||||
|
||||
# Find the last token that ends before or at the end_offset
|
||||
end_token = next(
|
||||
(
|
||||
i
|
||||
for i in range(len(token_offsets) - 1, -1, -1)
|
||||
if token_offsets[i][1] <= end_offset
|
||||
),
|
||||
-1,
|
||||
)
|
||||
if (
|
||||
end_token < len(token_offsets) - 1
|
||||
and token_offsets[end_token + 1][0] < end_offset
|
||||
):
|
||||
end_token += 1
|
||||
|
||||
if begin_token <= end_token:
|
||||
adjusted_begin = token_offsets[begin_token][0]
|
||||
adjusted_end = token_offsets[end_token][1]
|
||||
|
||||
if adjusted_begin != begin_offset or adjusted_end != end_offset:
|
||||
LOG.warning(
|
||||
f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
|
||||
)
|
||||
|
||||
adjusted_details.append(
|
||||
{
|
||||
"begin_offset": adjusted_begin,
|
||||
"end_offset": adjusted_end,
|
||||
"train": detail["train"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
|
||||
)
|
||||
|
||||
return adjusted_details
|
||||
|
||||
|
||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
@@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
_messages = "conversations"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos="last",
|
||||
):
|
||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
||||
self.train_on_eos = train_on_eos
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
@@ -79,62 +201,169 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self._messages = messages
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
|
||||
turns = prompt[self.messages]
|
||||
input_ids = self.prompter.build_prompt(turns)
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(prompt_ids)
|
||||
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
|
||||
else:
|
||||
labels = input_ids
|
||||
last_eos_idx = -1
|
||||
for index, turn in enumerate(turns):
|
||||
role = turn.get(self.prompter.message_field_role)
|
||||
content = turn.get(self.prompter.message_field_content)
|
||||
train_turn = turn.get(self.prompter.message_field_training)
|
||||
train_detail = turn.get(self.prompter.message_field_training_detail)
|
||||
|
||||
tokenized_prompt = {
|
||||
LOG.debug(
|
||||
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
|
||||
)
|
||||
|
||||
should_train = (
|
||||
train_turn
|
||||
if train_turn is not None
|
||||
else bool(train_detail is not None)
|
||||
if train_detail is not None
|
||||
else self.train_on_inputs or role in self.roles_to_train
|
||||
)
|
||||
|
||||
LOG.debug(f"Should train: {should_train}")
|
||||
|
||||
turn_start_idx, turn_end_idx = self.find_turn(
|
||||
conversation_ids=input_ids, turn=index, turn_content=turn
|
||||
)
|
||||
|
||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||
|
||||
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
||||
if train_detail:
|
||||
token_offsets = self.prompter.get_offsets_for_train_detail(
|
||||
content, train_detail
|
||||
)
|
||||
LOG.debug(f"Token offsets: {token_offsets}")
|
||||
for i, offset in enumerate(token_offsets):
|
||||
if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
|
||||
input_ids
|
||||
):
|
||||
labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
|
||||
LOG.debug(
|
||||
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
|
||||
)
|
||||
else:
|
||||
labels[turn_start_idx:turn_end_idx] = input_ids[
|
||||
turn_start_idx:turn_end_idx
|
||||
]
|
||||
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
|
||||
|
||||
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||
|
||||
# Handle EOS token
|
||||
eos_idx = self.find_eos_token(input_ids, turn_end_idx)
|
||||
if eos_idx == turn_end_idx:
|
||||
last_eos_idx = eos_idx
|
||||
if self.train_on_eos == "all" or (
|
||||
self.train_on_eos == "turn" and should_train
|
||||
):
|
||||
labels[eos_idx] = input_ids[eos_idx]
|
||||
LOG.debug(f"EOS token set for training at index {eos_idx}")
|
||||
else:
|
||||
LOG.debug(
|
||||
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
|
||||
)
|
||||
|
||||
# Handle 'last' option for train_on_eos
|
||||
if self.train_on_eos == "last" and last_eos_idx != -1:
|
||||
labels[last_eos_idx] = input_ids[last_eos_idx]
|
||||
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
|
||||
|
||||
LOG.debug(f"Final labels: {labels}")
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(input_ids),
|
||||
}
|
||||
|
||||
return tokenized_prompt
|
||||
def find_eos_token(self, input_ids, start_idx):
|
||||
eos_token_id = self.tokenizer.eos_token_id
|
||||
for i in range(start_idx, len(input_ids)):
|
||||
if input_ids[i] == eos_token_id:
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_turn(self, conversation_ids, turn, turn_content):
|
||||
"""
|
||||
Locate the starting and ending indices of the specified turn in a conversation.
|
||||
|
||||
Args:
|
||||
conversation_ids (list[int]): Token IDs representing the conversation.
|
||||
turn (int): The turn number to locate (based on EOS tokens).
|
||||
turn_content (str): String containing the content of the turn.
|
||||
|
||||
Returns:
|
||||
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
|
||||
Returns (-1, -1) if the turn content is not found.
|
||||
"""
|
||||
content = turn_content.get(self.prompter.message_field_content, "")
|
||||
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
|
||||
|
||||
eos_token_id = self.tokenizer.eos_token_id
|
||||
eos_count = 0
|
||||
start_search_idx = 0
|
||||
|
||||
# Locate the starting index after the specified number of EOS tokens
|
||||
for i, token_id in enumerate(conversation_ids):
|
||||
if token_id == eos_token_id:
|
||||
eos_count += 1
|
||||
if eos_count == turn:
|
||||
start_search_idx = (
|
||||
i + 1
|
||||
) # Start searching after the specified turn's EOS token
|
||||
break
|
||||
|
||||
# Find the start index of the content within the conversation
|
||||
start_idx = -1
|
||||
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
|
||||
if conversation_ids[i : i + len(content_ids)] == content_ids:
|
||||
start_idx = i
|
||||
break
|
||||
|
||||
if start_idx != -1:
|
||||
end_idx = start_idx + len(content_ids)
|
||||
else:
|
||||
end_idx = -1
|
||||
|
||||
return start_idx, end_idx
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt[self.messages]
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
chat_template = (
|
||||
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
||||
)
|
||||
message_field_role = (
|
||||
ds_cfg["message_field_role"]
|
||||
if ds_cfg and "message_field_role" in ds_cfg
|
||||
else "from"
|
||||
)
|
||||
message_field_content = (
|
||||
ds_cfg["message_field_content"]
|
||||
if ds_cfg and "message_field_content" in ds_cfg
|
||||
else "value"
|
||||
)
|
||||
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
||||
drop_system_message = (
|
||||
ds_cfg["drop_system_message"]
|
||||
if ds_cfg and "drop_system_message" in ds_cfg
|
||||
else False
|
||||
)
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", "training"),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail", "train_detail"
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
}
|
||||
|
||||
strategy_params = {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train"),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "last"),
|
||||
}
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_templates(chat_template),
|
||||
message_field_role=message_field_role,
|
||||
message_field_content=message_field_content,
|
||||
roles=roles,
|
||||
drop_system_message=drop_system_message,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
|
||||
@@ -116,6 +116,10 @@ class SFTDataset(BaseModel):
|
||||
field_messages: Optional[str] = None
|
||||
message_field_role: Optional[str] = None
|
||||
message_field_content: Optional[str] = None
|
||||
message_field_training: Optional[str] = None
|
||||
message_field_training_detail: Optional[str] = None
|
||||
roles_to_train: Optional[List[str]] = None
|
||||
train_on_eos: Optional[str] = None
|
||||
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
drop_system_message: Optional[bool] = None
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import (
|
||||
ChatTemplateStrategy,
|
||||
load,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
@pytest.fixture(name="assistant_dataset")
|
||||
def fixture_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "goodbye",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "goodbye",
|
||||
},
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
{"role": "user", "content": "goodbye"},
|
||||
{"role": "assistant", "content": "goodbye"},
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -53,22 +45,28 @@ def fixture_sharegpt_dataset():
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "hello",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "hello",
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "goodbye",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "goodbye",
|
||||
},
|
||||
{"from": "human", "value": "hello"},
|
||||
{"from": "gpt", "value": "hello"},
|
||||
{"from": "human", "value": "goodbye"},
|
||||
{"from": "gpt", "value": "goodbye"},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="basic_dataset")
|
||||
def fixture_basic_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{"from": "system", "value": "You are an AI assistant."},
|
||||
{"from": "human", "value": "Hello"},
|
||||
{"from": "assistant", "value": "Hi there!"},
|
||||
{"from": "human", "value": "How are you?"},
|
||||
{"from": "assistant", "value": "I'm doing well, thank you!"},
|
||||
]
|
||||
}
|
||||
]
|
||||
@@ -77,19 +75,611 @@ def fixture_sharegpt_dataset():
|
||||
|
||||
@pytest.fixture(name="llama3_tokenizer")
|
||||
def fixture_llama3_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||
tokenizer.eos_token = "<|eot_id|>"
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestChatTemplateConfigurations:
|
||||
"""
|
||||
Test class for various configurations of ChatTemplateStrategy.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def find_sublist(full_list, sub_list):
|
||||
token_count = len(sub_list)
|
||||
for index in range(len(full_list) - token_count + 1):
|
||||
if full_list[index : index + token_count] == sub_list:
|
||||
return index
|
||||
return -1
|
||||
|
||||
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with train_on_inputs=True")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=True,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Verify that assistant responses are labeled
|
||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||
for response in assistant_responses:
|
||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, response_ids)
|
||||
LOG.debug(
|
||||
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||
)
|
||||
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||
|
||||
# Check the behavior of human inputs
|
||||
human_inputs = ["Hello", "How are you?"]
|
||||
for input_text in human_inputs:
|
||||
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, input_ids)
|
||||
labeled = all(
|
||||
label != IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(input_ids)]
|
||||
)
|
||||
LOG.debug(
|
||||
f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
|
||||
)
|
||||
|
||||
LOG.debug("Full labels: %s", labels)
|
||||
LOG.debug("Full input_ids: %s", input_ids)
|
||||
|
||||
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with train_on_inputs=False")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Verify that only assistant responses are labeled
|
||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||
for response in assistant_responses:
|
||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, response_ids)
|
||||
LOG.debug(
|
||||
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||
)
|
||||
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||
|
||||
# Verify that human inputs are not labeled
|
||||
human_inputs = ["Hello", "How are you?"]
|
||||
for input_text in human_inputs:
|
||||
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, input_ids)
|
||||
LOG.debug(
|
||||
f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
|
||||
)
|
||||
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(input_ids)]
|
||||
), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
|
||||
|
||||
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing roles_to_train with assistant only")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Verify that only assistant responses are labeled
|
||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||
for response in assistant_responses:
|
||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, response_ids)
|
||||
LOG.debug(
|
||||
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||
)
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||
|
||||
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing roles_to_train with all roles")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=True,
|
||||
sequence_len=512,
|
||||
roles_to_train=["human", "assistant"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Verify that all responses are labeled (except for special tokens)
|
||||
all_responses = [
|
||||
"Hello",
|
||||
"Hi there!",
|
||||
"How are you?",
|
||||
"I'm doing well, thank you!",
|
||||
]
|
||||
for response in all_responses:
|
||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, response_ids)
|
||||
LOG.debug(
|
||||
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||
)
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||
|
||||
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with empty roles_to_train")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=[],
|
||||
train_on_eos="none", # Add this line
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
|
||||
# Verify that no labels are set when roles_to_train is empty
|
||||
LOG.debug("Full labels: %s", labels)
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels
|
||||
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||
|
||||
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with train_on_eos='all'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="all",
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
eos_token_id = llama3_tokenizer.eos_token_id
|
||||
eos_indices = [
|
||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||
]
|
||||
|
||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||
for eos_idx in eos_indices:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {eos_idx} to be labeled"
|
||||
|
||||
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with train_on_eos='turn'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="turn",
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
eos_token_id = llama3_tokenizer.eos_token_id
|
||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||
|
||||
for response in assistant_responses:
|
||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, response_ids)
|
||||
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||
|
||||
eos_idx = start_idx + len(response_ids)
|
||||
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||
eos_idx += 1
|
||||
|
||||
assert eos_idx < len(
|
||||
input_ids
|
||||
), f"Could not find EOS token after '{response}'"
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||
|
||||
# Check that EOS tokens after human inputs are not labeled
|
||||
human_inputs = ["Hello", "How are you?"]
|
||||
for input_text in human_inputs:
|
||||
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, input_ids)
|
||||
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
||||
|
||||
eos_idx = start_idx + len(input_ids)
|
||||
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||
eos_idx += 1
|
||||
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token after human input '{input_text}' to not be labeled"
|
||||
|
||||
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with train_on_eos='last'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="last",
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
eos_token_id = llama3_tokenizer.eos_token_id
|
||||
eos_indices = [
|
||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||
]
|
||||
|
||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||
last_eos_idx = eos_indices[-1]
|
||||
|
||||
# Check that only the last EOS token is labeled
|
||||
for idx in eos_indices[:-1]:
|
||||
assert (
|
||||
labels[idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {idx} to not be labeled"
|
||||
assert (
|
||||
labels[last_eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||
|
||||
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with train_on_eos='none'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="none",
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
eos_token_id = llama3_tokenizer.eos_token_id
|
||||
eos_indices = [
|
||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||
]
|
||||
|
||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||
for eos_idx in eos_indices:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||
|
||||
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing with drop_system_message=True")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Check if system message is not present in input_ids
|
||||
system_message = "You are an AI assistant."
|
||||
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
|
||||
assert (
|
||||
self.find_sublist(input_ids, system_ids) == -1
|
||||
), "Expected system message to be dropped"
|
||||
|
||||
def test_custom_roles(self, llama3_tokenizer):
|
||||
LOG.info("Testing with custom roles mapping")
|
||||
custom_roles = {
|
||||
"user": ["human", "user"],
|
||||
"assistant": ["ai", "assistant"],
|
||||
"system": ["context"],
|
||||
}
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["ai"],
|
||||
)
|
||||
|
||||
# Create a new dataset with modified role names
|
||||
modified_conversations = [
|
||||
{"from": "context", "value": "You are an AI assistant."},
|
||||
{"from": "human", "value": "Hello"},
|
||||
{"from": "ai", "value": "Hi there!"},
|
||||
{"from": "human", "value": "How are you?"},
|
||||
{"from": "ai", "value": "I'm doing well, thank you!"},
|
||||
]
|
||||
|
||||
modified_dataset = Dataset.from_dict(
|
||||
{"conversations": [modified_conversations]}
|
||||
)
|
||||
|
||||
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Check if AI responses are labeled correctly
|
||||
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||
for response in ai_responses:
|
||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, response_ids)
|
||||
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||
), f"Expected labels for AI response '{response}' to be set"
|
||||
|
||||
# Check if human messages are not labeled
|
||||
human_messages = ["Hello", "How are you?"]
|
||||
for message in human_messages:
|
||||
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
|
||||
start_idx = self.find_sublist(input_ids, message_ids)
|
||||
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID
|
||||
for label in labels[start_idx : start_idx + len(message_ids)]
|
||||
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
|
||||
|
||||
def test_message_field_training(self, llama3_tokenizer):
|
||||
LOG.info("Testing with message_field_training")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_templates("llama3"),
|
||||
message_field_training="train",
|
||||
message_field_training_detail="train_detail",
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=[],
|
||||
)
|
||||
|
||||
# Create a new dataset with the train and train_detail fields
|
||||
modified_conversation = [
|
||||
{"from": "system", "value": "You are an AI assistant.", "train": False},
|
||||
{"from": "human", "value": "Hello", "train": False},
|
||||
{"from": "assistant", "value": "Hello", "train": True},
|
||||
{"from": "human", "value": "How are you?", "train": True},
|
||||
{
|
||||
"from": "assistant",
|
||||
"value": "I'm doing very well, thank you!",
|
||||
"train_detail": [
|
||||
{"begin_offset": 0, "end_offset": 8, "train": False},
|
||||
{"begin_offset": 9, "end_offset": 18, "train": True},
|
||||
{"begin_offset": 19, "end_offset": 30, "train": False},
|
||||
],
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "I'm doing very well, thank you!",
|
||||
"train": False,
|
||||
},
|
||||
{"from": "assistant", "value": "Hi there!", "train": True},
|
||||
]
|
||||
|
||||
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
|
||||
|
||||
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Function to find all occurrences of a sublist
|
||||
def find_all_sublists(full_list, sub_list):
|
||||
indices = []
|
||||
for index in range(len(full_list) - len(sub_list) + 1):
|
||||
if full_list[index : index + len(sub_list)] == sub_list:
|
||||
indices.append(index)
|
||||
return indices
|
||||
|
||||
# Keep track of which occurrences we've processed
|
||||
processed_occurrences = {}
|
||||
# Check if messages are labeled correctly based on train or train_detail
|
||||
for i, turn in enumerate(modified_conversation):
|
||||
turn_tokens = llama3_tokenizer.encode(
|
||||
turn["value"], add_special_tokens=False
|
||||
)
|
||||
occurrences = find_all_sublists(input_ids, turn_tokens)
|
||||
turn_key = turn["value"]
|
||||
if turn_key not in processed_occurrences:
|
||||
processed_occurrences[turn_key] = 0
|
||||
current_occurrence = processed_occurrences[turn_key]
|
||||
|
||||
if current_occurrence >= len(occurrences):
|
||||
assert (
|
||||
False
|
||||
), f"Not enough occurrences found for message: {turn['value']}"
|
||||
|
||||
start_idx = occurrences[current_occurrence]
|
||||
processed_occurrences[turn_key] += 1
|
||||
end_idx = start_idx + len(turn_tokens)
|
||||
|
||||
LOG.debug(
|
||||
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
|
||||
)
|
||||
|
||||
if "train_detail" in turn:
|
||||
# Get token offsets
|
||||
tokenized_output = llama3_tokenizer(
|
||||
turn["value"], return_offsets_mapping=True, add_special_tokens=False
|
||||
)
|
||||
token_offsets = tokenized_output["offset_mapping"]
|
||||
|
||||
# Adjust token offsets as done in the implementation
|
||||
for i in range(len(token_offsets) - 1):
|
||||
token_offsets[i] = (
|
||||
token_offsets[i][0],
|
||||
token_offsets[i + 1][0] - 1,
|
||||
)
|
||||
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
|
||||
|
||||
# Adjust train_details
|
||||
adjusted_train_details = strategy.prompter.adjust_train_details(
|
||||
turn["train_detail"], token_offsets
|
||||
)
|
||||
|
||||
LOG.debug(f"Original train_details: {turn['train_detail']}")
|
||||
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
|
||||
|
||||
# Handle train_detail
|
||||
token_offsets = strategy.prompter.get_offsets_for_train_detail(
|
||||
text=turn["value"],
|
||||
train_details=adjusted_train_details,
|
||||
mask_untrainable=False,
|
||||
)
|
||||
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
|
||||
text=turn["value"],
|
||||
train_details=adjusted_train_details,
|
||||
mask_untrainable=True,
|
||||
)
|
||||
LOG.debug(f"Token offsets: {token_offsets_masked}")
|
||||
|
||||
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
|
||||
for i, offset in enumerate(token_offsets_masked):
|
||||
if offset != IGNORE_TOKEN_ID:
|
||||
expected_labels[i] = turn_tokens[i]
|
||||
actual_labels = labels[
|
||||
start_idx : start_idx + len(token_offsets_masked)
|
||||
]
|
||||
assert (
|
||||
actual_labels == expected_labels
|
||||
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||
|
||||
for detail in adjusted_train_details:
|
||||
# Find the token indices that correspond to the character offsets
|
||||
detail_start = start_idx + next(
|
||||
i
|
||||
for i, offset in enumerate(token_offsets)
|
||||
if offset >= detail["begin_offset"]
|
||||
)
|
||||
detail_end = start_idx + next(
|
||||
(
|
||||
i
|
||||
for i, offset in enumerate(token_offsets)
|
||||
if offset > detail["end_offset"]
|
||||
),
|
||||
len(token_offsets),
|
||||
)
|
||||
|
||||
detail_text = turn["value"][
|
||||
detail["begin_offset"] : detail["end_offset"] + 1
|
||||
]
|
||||
detail_labels = labels[detail_start:detail_end]
|
||||
detail_input_ids = input_ids[detail_start:detail_end]
|
||||
|
||||
LOG.debug(
|
||||
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
|
||||
)
|
||||
LOG.debug(f"Detail input_ids: {detail_input_ids}")
|
||||
LOG.debug(f"Detail labels: {detail_labels}")
|
||||
LOG.debug(
|
||||
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
|
||||
)
|
||||
LOG.debug(
|
||||
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
|
||||
)
|
||||
|
||||
if detail["train"]:
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in detail_labels
|
||||
), (
|
||||
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
|
||||
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||
f"InputIDs: {detail_input_ids}, "
|
||||
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in detail_labels
|
||||
), (
|
||||
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
|
||||
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||
f"InputIDs: {detail_input_ids}, "
|
||||
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||
)
|
||||
else:
|
||||
should_train = turn.get("train", False)
|
||||
turn_labels = labels[start_idx:end_idx]
|
||||
|
||||
LOG.debug(f"Should train: {should_train}")
|
||||
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
|
||||
LOG.debug(f"Turn labels: {turn_labels}")
|
||||
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
|
||||
LOG.debug(
|
||||
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
|
||||
)
|
||||
|
||||
if should_train:
|
||||
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||
f"Expected all labels for '{turn['value']}' to be set\n"
|
||||
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||
)
|
||||
else:
|
||||
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
|
||||
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
|
||||
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||
)
|
||||
|
||||
LOG.debug(
|
||||
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
|
||||
f"start_idx: {start_idx}, end_idx: {end_idx}, "
|
||||
f"labels: {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
LOG.debug(f"Final labels: {labels}")
|
||||
LOG.debug(f"Final input_ids: {input_ids}")
|
||||
|
||||
|
||||
class TestAssistantChatTemplateLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
LOG.info("Loading llama-3 tokenizer with assistant dataset")
|
||||
strategy = load(
|
||||
llama3_tokenizer,
|
||||
DictDefault(
|
||||
@@ -115,21 +705,26 @@ class TestAssistantChatTemplateLlama3:
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
expected_input_ids = [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
|
||||
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
LOG.info("Testing llama-3 with assistant dataset")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
@@ -142,15 +737,16 @@ class TestAssistantChatTemplateLlama3:
|
||||
"system": ["system"],
|
||||
},
|
||||
),
|
||||
llama3_tokenizer,
|
||||
False,
|
||||
512,
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
expected_input_ids = [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
@@ -162,6 +758,64 @@ class TestAssistantChatTemplateLlama3:
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
|
||||
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing llama-3 with assistant dataset including training data")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_templates("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_field_training="training",
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
},
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
train_on_eos="none",
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
prompt_tokens = strategy.prompter.build_prompt(
|
||||
assistant_dataset[0]["messages"], False
|
||||
)
|
||||
prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
|
||||
LOG.debug(f"Generated prompt: {prompt}")
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
expected_labels = [
|
||||
IGNORE_TOKEN_ID, # bos
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch:\n"
|
||||
f"Expected: {expected_labels}\n"
|
||||
f"Actual: {labels}\n"
|
||||
f"Input IDs: {input_ids}\n"
|
||||
)
|
||||
|
||||
|
||||
class TestSharegptChatTemplateLlama3:
|
||||
@@ -169,30 +823,160 @@ class TestSharegptChatTemplateLlama3:
|
||||
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_llama3(self, llama3_tokenizer, sharegpt_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
llama3_tokenizer,
|
||||
False,
|
||||
512,
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
train_on_eos="none",
|
||||
sequence_len=512,
|
||||
roles_to_train=["gpt"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
expected_input_ids = [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
expected_labels = [
|
||||
IGNORE_TOKEN_ID, # bos
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
|
||||
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
train_on_eos="none",
|
||||
sequence_len=512,
|
||||
roles_to_train=["human"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
# fmt: off
|
||||
expected_input_ids = [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
expected_labels = [
|
||||
IGNORE_TOKEN_ID, # bos
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||
IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
|
||||
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
train_on_eos="none",
|
||||
sequence_len=512,
|
||||
roles_to_train=["system", "human"],
|
||||
)
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
# fmt: off
|
||||
expected_input_ids = [
|
||||
128000, # bos
|
||||
128006, 9125, 128007,
|
||||
271, 2675, 527, 459, 15592, 18328, 13, 128009,
|
||||
128006, 882, 128007, # user header
|
||||
271, 9906, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 13347, 1070, 0, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 4438, 527, 499, 30, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
|
||||
]
|
||||
expected_labels = [
|
||||
IGNORE_TOKEN_ID, # bos
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
|
||||
IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
||||
IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -192,6 +192,7 @@ class TestSharegptLlama3:
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
|
||||
# fmt: off
|
||||
# pylint: disable=duplicate-code
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 9125, 128007, # system header
|
||||
@@ -228,6 +229,7 @@ class TestSharegptLlama3:
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
|
||||
# fmt: off
|
||||
# pylint: disable=duplicate-code
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 9125, 128007, # system header
|
||||
|
||||
Reference in New Issue
Block a user