Trigger the original tokenization behavior when no advanced turn settings are provided (#1915)
This commit is contained in:
76
examples/phi/lora-3.5.yaml
Normal file
76
examples/phi/lora-3.5.yaml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: microsoft/Phi-3.5-mini-instruct
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: phi_3
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
chat_template: phi_3
|
||||||
|
field_messages: messages
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
roles:
|
||||||
|
user:
|
||||||
|
- user
|
||||||
|
assistant:
|
||||||
|
- assistant
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 2
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bfloat16: true
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 4
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -24,8 +24,8 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
max_length=2048,
|
max_length=2048,
|
||||||
message_field_role: str = "from",
|
message_field_role: str = "from",
|
||||||
message_field_content: str = "value",
|
message_field_content: str = "value",
|
||||||
message_field_training: str = "train",
|
message_field_training: Optional[str] = None,
|
||||||
message_field_training_detail: str = "train_detail",
|
message_field_training_detail: Optional[str] = None,
|
||||||
roles: Optional[Dict[str, List[str]]] = None,
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
drop_system_message: bool = False,
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
@@ -186,7 +186,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
train_on_inputs,
|
train_on_inputs,
|
||||||
sequence_len,
|
sequence_len,
|
||||||
roles_to_train=None,
|
roles_to_train=None,
|
||||||
train_on_eos="last",
|
train_on_eos=None,
|
||||||
):
|
):
|
||||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||||
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
||||||
@@ -201,6 +201,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
self._messages = messages
|
self._messages = messages
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
# Old simple legacy behavior that works reliably.
|
||||||
|
if (
|
||||||
|
not self.roles_to_train
|
||||||
|
and not self.train_on_eos
|
||||||
|
and not self.prompter.message_field_training
|
||||||
|
and not self.prompter.message_field_training_detail
|
||||||
|
):
|
||||||
|
turns = self.get_conversation_thread(prompt)
|
||||||
|
prompt_ids = self.prompter.build_prompt(
|
||||||
|
turns[:-1], add_generation_prompt=True
|
||||||
|
)
|
||||||
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
|
|
||||||
|
if not self.train_on_inputs:
|
||||||
|
user_prompt_len = len(prompt_ids)
|
||||||
|
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
|
||||||
|
else:
|
||||||
|
labels = input_ids
|
||||||
|
|
||||||
|
tokenized_prompt = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
"attention_mask": [1] * len(input_ids),
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenized_prompt
|
||||||
|
LOG.info(self.roles_to_train)
|
||||||
|
LOG.info(self.train_on_eos)
|
||||||
|
LOG.info(self.prompter.message_field_training)
|
||||||
|
LOG.info(self.prompter.message_field_training_detail)
|
||||||
|
|
||||||
turns = prompt[self.messages]
|
turns = prompt[self.messages]
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||||
@@ -219,9 +250,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
should_train = (
|
should_train = (
|
||||||
train_turn
|
train_turn
|
||||||
if train_turn is not None
|
if train_turn is not None
|
||||||
else bool(train_detail is not None)
|
else (
|
||||||
if train_detail is not None
|
bool(train_detail is not None)
|
||||||
else self.train_on_inputs or role in self.roles_to_train
|
if train_detail is not None
|
||||||
|
else self.train_on_inputs or role in self.roles_to_train
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.debug(f"Should train: {should_train}")
|
LOG.debug(f"Should train: {should_train}")
|
||||||
@@ -344,9 +377,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||||
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
||||||
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
||||||
"message_field_training": ds_cfg.get("message_field_training", "training"),
|
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||||
"message_field_training_detail": ds_cfg.get(
|
"message_field_training_detail": ds_cfg.get(
|
||||||
"message_field_training_detail", "train_detail"
|
"message_field_training_detail",
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
"roles": ds_cfg.get("roles"),
|
"roles": ds_cfg.get("roles"),
|
||||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||||
@@ -357,8 +391,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
strategy_params = {
|
strategy_params = {
|
||||||
"train_on_inputs": cfg.train_on_inputs,
|
"train_on_inputs": cfg.train_on_inputs,
|
||||||
"sequence_len": cfg.sequence_len,
|
"sequence_len": cfg.sequence_len,
|
||||||
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
|
"roles_to_train": ds_cfg.get("roles_to_train", []),
|
||||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
"train_on_eos": ds_cfg.get("train_on_eos", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -189,6 +189,7 @@ class ChatTemplate(str, Enum):
|
|||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||||
|
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||||
jamba = "jamba" # pylint: disable=invalid-name
|
jamba = "jamba" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|||||||
71
tests/prompt_strategies/conftest.py
Normal file
71
tests/prompt_strategies/conftest.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
shared fixtures for prompt strategies tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="assistant_dataset")
|
||||||
|
def fixture_assistant_dataset():
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "hello"},
|
||||||
|
{"role": "user", "content": "goodbye"},
|
||||||
|
{"role": "assistant", "content": "goodbye"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="sharegpt_dataset")
|
||||||
|
def fixture_sharegpt_dataset():
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "hello"},
|
||||||
|
{"from": "gpt", "value": "hello"},
|
||||||
|
{"from": "human", "value": "goodbye"},
|
||||||
|
{"from": "gpt", "value": "goodbye"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="basic_dataset")
|
||||||
|
def fixture_basic_dataset():
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "You are an AI assistant."},
|
||||||
|
{"from": "human", "value": "Hello"},
|
||||||
|
{"from": "assistant", "value": "Hi there!"},
|
||||||
|
{"from": "human", "value": "How are you?"},
|
||||||
|
{"from": "assistant", "value": "I'm doing well, thank you!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
|
def fixture_llama3_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="phi35_tokenizer")
|
||||||
|
def fixture_phi35_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
|
||||||
|
return tokenizer
|
||||||
@@ -5,10 +5,6 @@ tests for chat_template prompt strategy
|
|||||||
import logging
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
|
||||||
from datasets import Dataset
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies.chat_template import (
|
from axolotl.prompt_strategies.chat_template import (
|
||||||
ChatTemplatePrompter,
|
ChatTemplatePrompter,
|
||||||
ChatTemplateStrategy,
|
ChatTemplateStrategy,
|
||||||
@@ -22,657 +18,6 @@ logging.basicConfig(level=logging.DEBUG)
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="assistant_dataset")
|
|
||||||
def fixture_assistant_dataset():
|
|
||||||
return Dataset.from_list(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "hello"},
|
|
||||||
{"role": "assistant", "content": "hello"},
|
|
||||||
{"role": "user", "content": "goodbye"},
|
|
||||||
{"role": "assistant", "content": "goodbye"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sharegpt_dataset")
|
|
||||||
def fixture_sharegpt_dataset():
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
return Dataset.from_list(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{"from": "human", "value": "hello"},
|
|
||||||
{"from": "gpt", "value": "hello"},
|
|
||||||
{"from": "human", "value": "goodbye"},
|
|
||||||
{"from": "gpt", "value": "goodbye"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="basic_dataset")
|
|
||||||
def fixture_basic_dataset():
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
return Dataset.from_list(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "You are an AI assistant."},
|
|
||||||
{"from": "human", "value": "Hello"},
|
|
||||||
{"from": "assistant", "value": "Hi there!"},
|
|
||||||
{"from": "human", "value": "How are you?"},
|
|
||||||
{"from": "assistant", "value": "I'm doing well, thank you!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="llama3_tokenizer")
|
|
||||||
def fixture_llama3_tokenizer():
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class TestChatTemplateConfigurations:
|
|
||||||
"""
|
|
||||||
Test class for various configurations of ChatTemplateStrategy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def find_sublist(full_list, sub_list):
|
|
||||||
token_count = len(sub_list)
|
|
||||||
for index in range(len(full_list) - token_count + 1):
|
|
||||||
if full_list[index : index + token_count] == sub_list:
|
|
||||||
return index
|
|
||||||
return -1
|
|
||||||
|
|
||||||
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with train_on_inputs=True")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=True,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
# Verify that assistant responses are labeled
|
|
||||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
|
||||||
for response in assistant_responses:
|
|
||||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, response_ids)
|
|
||||||
LOG.debug(
|
|
||||||
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
|
||||||
)
|
|
||||||
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
|
||||||
assert all(
|
|
||||||
label != IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
|
||||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
|
||||||
|
|
||||||
# Check the behavior of human inputs
|
|
||||||
human_inputs = ["Hello", "How are you?"]
|
|
||||||
for input_text in human_inputs:
|
|
||||||
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, input_ids)
|
|
||||||
labeled = all(
|
|
||||||
label != IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(input_ids)]
|
|
||||||
)
|
|
||||||
LOG.debug(
|
|
||||||
f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.debug("Full labels: %s", labels)
|
|
||||||
LOG.debug("Full input_ids: %s", input_ids)
|
|
||||||
|
|
||||||
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with train_on_inputs=False")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
# Verify that only assistant responses are labeled
|
|
||||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
|
||||||
for response in assistant_responses:
|
|
||||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, response_ids)
|
|
||||||
LOG.debug(
|
|
||||||
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
|
||||||
)
|
|
||||||
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
|
||||||
assert all(
|
|
||||||
label != IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
|
||||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
|
||||||
|
|
||||||
# Verify that human inputs are not labeled
|
|
||||||
human_inputs = ["Hello", "How are you?"]
|
|
||||||
for input_text in human_inputs:
|
|
||||||
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, input_ids)
|
|
||||||
LOG.debug(
|
|
||||||
f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
|
|
||||||
)
|
|
||||||
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
|
||||||
assert all(
|
|
||||||
label == IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(input_ids)]
|
|
||||||
), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
|
|
||||||
|
|
||||||
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing roles_to_train with assistant only")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
# Verify that only assistant responses are labeled
|
|
||||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
|
||||||
for response in assistant_responses:
|
|
||||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, response_ids)
|
|
||||||
LOG.debug(
|
|
||||||
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
|
||||||
)
|
|
||||||
assert all(
|
|
||||||
label != IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
|
||||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
|
||||||
|
|
||||||
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing roles_to_train with all roles")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=True,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["human", "assistant"],
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
# Verify that all responses are labeled (except for special tokens)
|
|
||||||
all_responses = [
|
|
||||||
"Hello",
|
|
||||||
"Hi there!",
|
|
||||||
"How are you?",
|
|
||||||
"I'm doing well, thank you!",
|
|
||||||
]
|
|
||||||
for response in all_responses:
|
|
||||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, response_ids)
|
|
||||||
LOG.debug(
|
|
||||||
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
|
||||||
)
|
|
||||||
assert all(
|
|
||||||
label != IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
|
||||||
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
|
||||||
|
|
||||||
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with empty roles_to_train")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=[],
|
|
||||||
train_on_eos="none", # Add this line
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
|
|
||||||
# Verify that no labels are set when roles_to_train is empty
|
|
||||||
LOG.debug("Full labels: %s", labels)
|
|
||||||
assert all(
|
|
||||||
label == IGNORE_TOKEN_ID for label in labels
|
|
||||||
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
|
||||||
|
|
||||||
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with train_on_eos='all'")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
train_on_eos="all",
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
eos_token_id = llama3_tokenizer.eos_token_id
|
|
||||||
eos_indices = [
|
|
||||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
|
||||||
for eos_idx in eos_indices:
|
|
||||||
assert (
|
|
||||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
|
||||||
), f"Expected EOS token at index {eos_idx} to be labeled"
|
|
||||||
|
|
||||||
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with train_on_eos='turn'")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
train_on_eos="turn",
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
eos_token_id = llama3_tokenizer.eos_token_id
|
|
||||||
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
|
||||||
|
|
||||||
for response in assistant_responses:
|
|
||||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, response_ids)
|
|
||||||
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
|
||||||
|
|
||||||
eos_idx = start_idx + len(response_ids)
|
|
||||||
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
|
||||||
eos_idx += 1
|
|
||||||
|
|
||||||
assert eos_idx < len(
|
|
||||||
input_ids
|
|
||||||
), f"Could not find EOS token after '{response}'"
|
|
||||||
assert (
|
|
||||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
|
||||||
), f"Expected EOS token after assistant response '{response}' to be labeled"
|
|
||||||
|
|
||||||
# Check that EOS tokens after human inputs are not labeled
|
|
||||||
human_inputs = ["Hello", "How are you?"]
|
|
||||||
for input_text in human_inputs:
|
|
||||||
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, input_ids)
|
|
||||||
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
|
||||||
|
|
||||||
eos_idx = start_idx + len(input_ids)
|
|
||||||
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
|
||||||
eos_idx += 1
|
|
||||||
|
|
||||||
assert (
|
|
||||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
|
||||||
), f"Expected EOS token after human input '{input_text}' to not be labeled"
|
|
||||||
|
|
||||||
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with train_on_eos='last'")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
train_on_eos="last",
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
eos_token_id = llama3_tokenizer.eos_token_id
|
|
||||||
eos_indices = [
|
|
||||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
|
||||||
last_eos_idx = eos_indices[-1]
|
|
||||||
|
|
||||||
# Check that only the last EOS token is labeled
|
|
||||||
for idx in eos_indices[:-1]:
|
|
||||||
assert (
|
|
||||||
labels[idx] == IGNORE_TOKEN_ID
|
|
||||||
), f"Expected EOS token at index {idx} to not be labeled"
|
|
||||||
assert (
|
|
||||||
labels[last_eos_idx] != IGNORE_TOKEN_ID
|
|
||||||
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
|
||||||
|
|
||||||
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with train_on_eos='none'")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
train_on_eos="none",
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
eos_token_id = llama3_tokenizer.eos_token_id
|
|
||||||
eos_indices = [
|
|
||||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
|
||||||
for eos_idx in eos_indices:
|
|
||||||
assert (
|
|
||||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
|
||||||
), f"Expected EOS token at index {eos_idx} to not be labeled"
|
|
||||||
|
|
||||||
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
|
|
||||||
LOG.info("Testing with drop_system_message=True")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(
|
|
||||||
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
|
|
||||||
),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["assistant"],
|
|
||||||
)
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
# Check if system message is not present in input_ids
|
|
||||||
system_message = "You are an AI assistant."
|
|
||||||
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
|
|
||||||
assert (
|
|
||||||
self.find_sublist(input_ids, system_ids) == -1
|
|
||||||
), "Expected system message to be dropped"
|
|
||||||
|
|
||||||
def test_custom_roles(self, llama3_tokenizer):
|
|
||||||
LOG.info("Testing with custom roles mapping")
|
|
||||||
custom_roles = {
|
|
||||||
"user": ["human", "user"],
|
|
||||||
"assistant": ["ai", "assistant"],
|
|
||||||
"system": ["context"],
|
|
||||||
}
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(
|
|
||||||
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
|
|
||||||
),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=["ai"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a new dataset with modified role names
|
|
||||||
modified_conversations = [
|
|
||||||
{"from": "context", "value": "You are an AI assistant."},
|
|
||||||
{"from": "human", "value": "Hello"},
|
|
||||||
{"from": "ai", "value": "Hi there!"},
|
|
||||||
{"from": "human", "value": "How are you?"},
|
|
||||||
{"from": "ai", "value": "I'm doing well, thank you!"},
|
|
||||||
]
|
|
||||||
|
|
||||||
modified_dataset = Dataset.from_dict(
|
|
||||||
{"conversations": [modified_conversations]}
|
|
||||||
)
|
|
||||||
|
|
||||||
res = strategy.tokenize_prompt(modified_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
# Check if AI responses are labeled correctly
|
|
||||||
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
|
|
||||||
for response in ai_responses:
|
|
||||||
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, response_ids)
|
|
||||||
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
|
|
||||||
assert all(
|
|
||||||
label != IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(response_ids)]
|
|
||||||
), f"Expected labels for AI response '{response}' to be set"
|
|
||||||
|
|
||||||
# Check if human messages are not labeled
|
|
||||||
human_messages = ["Hello", "How are you?"]
|
|
||||||
for message in human_messages:
|
|
||||||
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
|
|
||||||
start_idx = self.find_sublist(input_ids, message_ids)
|
|
||||||
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
|
|
||||||
assert all(
|
|
||||||
label == IGNORE_TOKEN_ID
|
|
||||||
for label in labels[start_idx : start_idx + len(message_ids)]
|
|
||||||
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
|
|
||||||
|
|
||||||
def test_message_field_training(self, llama3_tokenizer):
|
|
||||||
LOG.info("Testing with message_field_training")
|
|
||||||
strategy = ChatTemplateStrategy(
|
|
||||||
ChatTemplatePrompter(
|
|
||||||
llama3_tokenizer,
|
|
||||||
chat_templates("llama3"),
|
|
||||||
message_field_training="train",
|
|
||||||
message_field_training_detail="train_detail",
|
|
||||||
),
|
|
||||||
tokenizer=llama3_tokenizer,
|
|
||||||
train_on_inputs=False,
|
|
||||||
sequence_len=512,
|
|
||||||
roles_to_train=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a new dataset with the train and train_detail fields
|
|
||||||
modified_conversation = [
|
|
||||||
{"from": "system", "value": "You are an AI assistant.", "train": False},
|
|
||||||
{"from": "human", "value": "Hello", "train": False},
|
|
||||||
{"from": "assistant", "value": "Hello", "train": True},
|
|
||||||
{"from": "human", "value": "How are you?", "train": True},
|
|
||||||
{
|
|
||||||
"from": "assistant",
|
|
||||||
"value": "I'm doing very well, thank you!",
|
|
||||||
"train_detail": [
|
|
||||||
{"begin_offset": 0, "end_offset": 8, "train": False},
|
|
||||||
{"begin_offset": 9, "end_offset": 18, "train": True},
|
|
||||||
{"begin_offset": 19, "end_offset": 30, "train": False},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"value": "I'm doing very well, thank you!",
|
|
||||||
"train": False,
|
|
||||||
},
|
|
||||||
{"from": "assistant", "value": "Hi there!", "train": True},
|
|
||||||
]
|
|
||||||
|
|
||||||
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
|
|
||||||
|
|
||||||
res = strategy.tokenize_prompt(modified_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
# Function to find all occurrences of a sublist
|
|
||||||
def find_all_sublists(full_list, sub_list):
|
|
||||||
indices = []
|
|
||||||
for index in range(len(full_list) - len(sub_list) + 1):
|
|
||||||
if full_list[index : index + len(sub_list)] == sub_list:
|
|
||||||
indices.append(index)
|
|
||||||
return indices
|
|
||||||
|
|
||||||
# Keep track of which occurrences we've processed
|
|
||||||
processed_occurrences = {}
|
|
||||||
# Check if messages are labeled correctly based on train or train_detail
|
|
||||||
for i, turn in enumerate(modified_conversation):
|
|
||||||
turn_tokens = llama3_tokenizer.encode(
|
|
||||||
turn["value"], add_special_tokens=False
|
|
||||||
)
|
|
||||||
occurrences = find_all_sublists(input_ids, turn_tokens)
|
|
||||||
turn_key = turn["value"]
|
|
||||||
if turn_key not in processed_occurrences:
|
|
||||||
processed_occurrences[turn_key] = 0
|
|
||||||
current_occurrence = processed_occurrences[turn_key]
|
|
||||||
|
|
||||||
if current_occurrence >= len(occurrences):
|
|
||||||
assert (
|
|
||||||
False
|
|
||||||
), f"Not enough occurrences found for message: {turn['value']}"
|
|
||||||
|
|
||||||
start_idx = occurrences[current_occurrence]
|
|
||||||
processed_occurrences[turn_key] += 1
|
|
||||||
end_idx = start_idx + len(turn_tokens)
|
|
||||||
|
|
||||||
LOG.debug(
|
|
||||||
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "train_detail" in turn:
|
|
||||||
# Get token offsets
|
|
||||||
tokenized_output = llama3_tokenizer(
|
|
||||||
turn["value"], return_offsets_mapping=True, add_special_tokens=False
|
|
||||||
)
|
|
||||||
token_offsets = tokenized_output["offset_mapping"]
|
|
||||||
|
|
||||||
# Adjust token offsets as done in the implementation
|
|
||||||
for i in range(len(token_offsets) - 1):
|
|
||||||
token_offsets[i] = (
|
|
||||||
token_offsets[i][0],
|
|
||||||
token_offsets[i + 1][0] - 1,
|
|
||||||
)
|
|
||||||
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
|
|
||||||
|
|
||||||
# Adjust train_details
|
|
||||||
adjusted_train_details = strategy.prompter.adjust_train_details(
|
|
||||||
turn["train_detail"], token_offsets
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.debug(f"Original train_details: {turn['train_detail']}")
|
|
||||||
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
|
|
||||||
|
|
||||||
# Handle train_detail
|
|
||||||
token_offsets = strategy.prompter.get_offsets_for_train_detail(
|
|
||||||
text=turn["value"],
|
|
||||||
train_details=adjusted_train_details,
|
|
||||||
mask_untrainable=False,
|
|
||||||
)
|
|
||||||
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
|
|
||||||
text=turn["value"],
|
|
||||||
train_details=adjusted_train_details,
|
|
||||||
mask_untrainable=True,
|
|
||||||
)
|
|
||||||
LOG.debug(f"Token offsets: {token_offsets_masked}")
|
|
||||||
|
|
||||||
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
|
|
||||||
for i, offset in enumerate(token_offsets_masked):
|
|
||||||
if offset != IGNORE_TOKEN_ID:
|
|
||||||
expected_labels[i] = turn_tokens[i]
|
|
||||||
actual_labels = labels[
|
|
||||||
start_idx : start_idx + len(token_offsets_masked)
|
|
||||||
]
|
|
||||||
assert (
|
|
||||||
actual_labels == expected_labels
|
|
||||||
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
|
||||||
|
|
||||||
for detail in adjusted_train_details:
|
|
||||||
# Find the token indices that correspond to the character offsets
|
|
||||||
detail_start = start_idx + next(
|
|
||||||
i
|
|
||||||
for i, offset in enumerate(token_offsets)
|
|
||||||
if offset >= detail["begin_offset"]
|
|
||||||
)
|
|
||||||
detail_end = start_idx + next(
|
|
||||||
(
|
|
||||||
i
|
|
||||||
for i, offset in enumerate(token_offsets)
|
|
||||||
if offset > detail["end_offset"]
|
|
||||||
),
|
|
||||||
len(token_offsets),
|
|
||||||
)
|
|
||||||
|
|
||||||
detail_text = turn["value"][
|
|
||||||
detail["begin_offset"] : detail["end_offset"] + 1
|
|
||||||
]
|
|
||||||
detail_labels = labels[detail_start:detail_end]
|
|
||||||
detail_input_ids = input_ids[detail_start:detail_end]
|
|
||||||
|
|
||||||
LOG.debug(
|
|
||||||
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
|
|
||||||
)
|
|
||||||
LOG.debug(f"Detail input_ids: {detail_input_ids}")
|
|
||||||
LOG.debug(f"Detail labels: {detail_labels}")
|
|
||||||
LOG.debug(
|
|
||||||
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
|
|
||||||
)
|
|
||||||
LOG.debug(
|
|
||||||
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if detail["train"]:
|
|
||||||
assert all(
|
|
||||||
label != IGNORE_TOKEN_ID for label in detail_labels
|
|
||||||
), (
|
|
||||||
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
|
|
||||||
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
|
||||||
f"InputIDs: {detail_input_ids}, "
|
|
||||||
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert all(
|
|
||||||
label == IGNORE_TOKEN_ID for label in detail_labels
|
|
||||||
), (
|
|
||||||
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
|
|
||||||
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
|
||||||
f"InputIDs: {detail_input_ids}, "
|
|
||||||
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
should_train = turn.get("train", False)
|
|
||||||
turn_labels = labels[start_idx:end_idx]
|
|
||||||
|
|
||||||
LOG.debug(f"Should train: {should_train}")
|
|
||||||
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
|
|
||||||
LOG.debug(f"Turn labels: {turn_labels}")
|
|
||||||
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
|
|
||||||
LOG.debug(
|
|
||||||
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_train:
|
|
||||||
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
|
||||||
f"Expected all labels for '{turn['value']}' to be set\n"
|
|
||||||
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
|
||||||
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
|
||||||
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
|
|
||||||
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
|
|
||||||
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
|
||||||
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
|
||||||
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.debug(
|
|
||||||
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
|
|
||||||
f"start_idx: {start_idx}, end_idx: {end_idx}, "
|
|
||||||
f"labels: {labels[start_idx:end_idx]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.debug(f"Final labels: {labels}")
|
|
||||||
LOG.debug(f"Final input_ids: {input_ids}")
|
|
||||||
|
|
||||||
|
|
||||||
class TestAssistantChatTemplateLlama3:
|
class TestAssistantChatTemplateLlama3:
|
||||||
"""
|
"""
|
||||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
@@ -740,7 +85,6 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["assistant"],
|
|
||||||
)
|
)
|
||||||
strategy.messages = "messages"
|
strategy.messages = "messages"
|
||||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
@@ -764,6 +108,64 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
input_ids == expected_input_ids
|
input_ids == expected_input_ids
|
||||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
|
||||||
|
def test_phi35(self, phi35_tokenizer, assistant_dataset):
|
||||||
|
LOG.info("Testing phi-3.5 with assistant dataset")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
phi35_tokenizer,
|
||||||
|
chat_templates("phi_35"),
|
||||||
|
message_field_role="role",
|
||||||
|
message_field_content="content",
|
||||||
|
roles={
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
"system": ["system"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
tokenizer=phi35_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
)
|
||||||
|
strategy.messages = "messages"
|
||||||
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
labels = res["labels"]
|
||||||
|
# fmt: off
|
||||||
|
expected_input_ids = [
|
||||||
|
32010, # user
|
||||||
|
22172, 32007, # user eot
|
||||||
|
32001, # assistant
|
||||||
|
22172, 32007, # assistant eot
|
||||||
|
32010, # user
|
||||||
|
1781, 26966, 32007, # user eot
|
||||||
|
32001, # assistant
|
||||||
|
1781, 26966, 32007, # assistant eot
|
||||||
|
32000, # eos
|
||||||
|
]
|
||||||
|
expected_labels = [
|
||||||
|
-100, # user
|
||||||
|
-100, -100, # user eot
|
||||||
|
-100, # assistant
|
||||||
|
-100, -100, # assistant eot,
|
||||||
|
-100, # user
|
||||||
|
-100, -100, -100, # user eot
|
||||||
|
-100, # assistant
|
||||||
|
1781, 26966, 32007, # assistant eot
|
||||||
|
32000, # eos
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||||
|
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||||
|
assert (
|
||||||
|
input_ids == expected_input_ids
|
||||||
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
|
||||||
|
LOG.debug(f"Expected labels : {expected_labels}")
|
||||||
|
LOG.debug(f"Actual labels : {labels}")
|
||||||
|
assert (
|
||||||
|
labels == expected_labels
|
||||||
|
), f"Input IDs mismatch: {labels} != {expected_labels}"
|
||||||
|
|
||||||
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
|
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
|
||||||
LOG.info("Testing llama-3 with assistant dataset including training data")
|
LOG.info("Testing llama-3 with assistant dataset including training data")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
|
|||||||
615
tests/prompt_strategies/test_chat_templates_advanced.py
Normal file
615
tests/prompt_strategies/test_chat_templates_advanced.py
Normal file
@@ -0,0 +1,615 @@
|
|||||||
|
"""
|
||||||
|
tests for chat_template prompt strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.chat_template import (
|
||||||
|
ChatTemplatePrompter,
|
||||||
|
ChatTemplateStrategy,
|
||||||
|
)
|
||||||
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatTemplateConfigurations:
|
||||||
|
"""
|
||||||
|
Test class for various configurations of ChatTemplateStrategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_sublist(full_list, sub_list):
|
||||||
|
token_count = len(sub_list)
|
||||||
|
for index in range(len(full_list) - token_count + 1):
|
||||||
|
if full_list[index : index + token_count] == sub_list:
|
||||||
|
return index
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_inputs=True")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=True,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that assistant responses are labeled
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
# Check the behavior of human inputs
|
||||||
|
human_inputs = ["Hello", "How are you?"]
|
||||||
|
for input_text in human_inputs:
|
||||||
|
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, input_ids)
|
||||||
|
labeled = all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(input_ids)]
|
||||||
|
)
|
||||||
|
LOG.debug(
|
||||||
|
f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug("Full labels: %s", labels)
|
||||||
|
LOG.debug("Full input_ids: %s", input_ids)
|
||||||
|
|
||||||
|
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_inputs=False")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that only assistant responses are labeled
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
# Verify that human inputs are not labeled
|
||||||
|
human_inputs = ["Hello", "How are you?"]
|
||||||
|
for input_text in human_inputs:
|
||||||
|
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, input_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(input_ids)]
|
||||||
|
), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
|
||||||
|
|
||||||
|
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing roles_to_train with assistant only")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that only assistant responses are labeled
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing roles_to_train with all roles")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=True,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["human", "assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that all responses are labeled (except for special tokens)
|
||||||
|
all_responses = [
|
||||||
|
"Hello",
|
||||||
|
"Hi there!",
|
||||||
|
"How are you?",
|
||||||
|
"I'm doing well, thank you!",
|
||||||
|
]
|
||||||
|
for response in all_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with empty roles_to_train")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=[],
|
||||||
|
train_on_eos="none", # Add this line
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
|
||||||
|
# Verify that no labels are set when roles_to_train is empty
|
||||||
|
LOG.debug("Full labels: %s", labels)
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID for label in labels
|
||||||
|
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||||
|
|
||||||
|
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='all'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="all",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
for eos_idx in eos_indices:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {eos_idx} to be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='turn'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="turn",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||||
|
|
||||||
|
eos_idx = start_idx + len(response_ids)
|
||||||
|
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||||
|
eos_idx += 1
|
||||||
|
|
||||||
|
assert eos_idx < len(
|
||||||
|
input_ids
|
||||||
|
), f"Could not find EOS token after '{response}'"
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||||
|
|
||||||
|
# Check that EOS tokens after human inputs are not labeled
|
||||||
|
human_inputs = ["Hello", "How are you?"]
|
||||||
|
for input_text in human_inputs:
|
||||||
|
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, input_ids)
|
||||||
|
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
||||||
|
|
||||||
|
eos_idx = start_idx + len(input_ids)
|
||||||
|
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||||
|
eos_idx += 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token after human input '{input_text}' to not be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='last'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="last",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
last_eos_idx = eos_indices[-1]
|
||||||
|
|
||||||
|
# Check that only the last EOS token is labeled
|
||||||
|
for idx in eos_indices[:-1]:
|
||||||
|
assert (
|
||||||
|
labels[idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {idx} to not be labeled"
|
||||||
|
assert (
|
||||||
|
labels[last_eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='none'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="none",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
for eos_idx in eos_indices:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||||
|
|
||||||
|
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with drop_system_message=True")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Check if system message is not present in input_ids
|
||||||
|
system_message = "You are an AI assistant."
|
||||||
|
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
|
||||||
|
assert (
|
||||||
|
self.find_sublist(input_ids, system_ids) == -1
|
||||||
|
), "Expected system message to be dropped"
|
||||||
|
|
||||||
|
def test_custom_roles(self, llama3_tokenizer):
|
||||||
|
LOG.info("Testing with custom roles mapping")
|
||||||
|
custom_roles = {
|
||||||
|
"user": ["human", "user"],
|
||||||
|
"assistant": ["ai", "assistant"],
|
||||||
|
"system": ["context"],
|
||||||
|
}
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["ai"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new dataset with modified role names
|
||||||
|
modified_conversations = [
|
||||||
|
{"from": "context", "value": "You are an AI assistant."},
|
||||||
|
{"from": "human", "value": "Hello"},
|
||||||
|
{"from": "ai", "value": "Hi there!"},
|
||||||
|
{"from": "human", "value": "How are you?"},
|
||||||
|
{"from": "ai", "value": "I'm doing well, thank you!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
modified_dataset = Dataset.from_dict(
|
||||||
|
{"conversations": [modified_conversations]}
|
||||||
|
)
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Check if AI responses are labeled correctly
|
||||||
|
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in ai_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for AI response '{response}' to be set"
|
||||||
|
|
||||||
|
# Check if human messages are not labeled
|
||||||
|
human_messages = ["Hello", "How are you?"]
|
||||||
|
for message in human_messages:
|
||||||
|
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, message_ids)
|
||||||
|
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(message_ids)]
|
||||||
|
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
|
||||||
|
|
||||||
|
def test_message_field_training(self, llama3_tokenizer):
|
||||||
|
LOG.info("Testing with message_field_training")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer,
|
||||||
|
chat_templates("llama3"),
|
||||||
|
message_field_training="train",
|
||||||
|
message_field_training_detail="train_detail",
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new dataset with the train and train_detail fields
|
||||||
|
modified_conversation = [
|
||||||
|
{"from": "system", "value": "You are an AI assistant.", "train": False},
|
||||||
|
{"from": "human", "value": "Hello", "train": False},
|
||||||
|
{"from": "assistant", "value": "Hello", "train": True},
|
||||||
|
{"from": "human", "value": "How are you?", "train": True},
|
||||||
|
{
|
||||||
|
"from": "assistant",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train_detail": [
|
||||||
|
{"begin_offset": 0, "end_offset": 8, "train": False},
|
||||||
|
{"begin_offset": 9, "end_offset": 18, "train": True},
|
||||||
|
{"begin_offset": 19, "end_offset": 30, "train": False},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train": False,
|
||||||
|
},
|
||||||
|
{"from": "assistant", "value": "Hi there!", "train": True},
|
||||||
|
]
|
||||||
|
|
||||||
|
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Function to find all occurrences of a sublist
|
||||||
|
def find_all_sublists(full_list, sub_list):
|
||||||
|
indices = []
|
||||||
|
for index in range(len(full_list) - len(sub_list) + 1):
|
||||||
|
if full_list[index : index + len(sub_list)] == sub_list:
|
||||||
|
indices.append(index)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
# Keep track of which occurrences we've processed
|
||||||
|
processed_occurrences = {}
|
||||||
|
# Check if messages are labeled correctly based on train or train_detail
|
||||||
|
for i, turn in enumerate(modified_conversation):
|
||||||
|
turn_tokens = llama3_tokenizer.encode(
|
||||||
|
turn["value"], add_special_tokens=False
|
||||||
|
)
|
||||||
|
occurrences = find_all_sublists(input_ids, turn_tokens)
|
||||||
|
turn_key = turn["value"]
|
||||||
|
if turn_key not in processed_occurrences:
|
||||||
|
processed_occurrences[turn_key] = 0
|
||||||
|
current_occurrence = processed_occurrences[turn_key]
|
||||||
|
|
||||||
|
if current_occurrence >= len(occurrences):
|
||||||
|
assert (
|
||||||
|
False
|
||||||
|
), f"Not enough occurrences found for message: {turn['value']}"
|
||||||
|
|
||||||
|
start_idx = occurrences[current_occurrence]
|
||||||
|
processed_occurrences[turn_key] += 1
|
||||||
|
end_idx = start_idx + len(turn_tokens)
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "train_detail" in turn:
|
||||||
|
# Get token offsets
|
||||||
|
tokenized_output = llama3_tokenizer(
|
||||||
|
turn["value"], return_offsets_mapping=True, add_special_tokens=False
|
||||||
|
)
|
||||||
|
token_offsets = tokenized_output["offset_mapping"]
|
||||||
|
|
||||||
|
# Adjust token offsets as done in the implementation
|
||||||
|
for i in range(len(token_offsets) - 1):
|
||||||
|
token_offsets[i] = (
|
||||||
|
token_offsets[i][0],
|
||||||
|
token_offsets[i + 1][0] - 1,
|
||||||
|
)
|
||||||
|
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
|
||||||
|
|
||||||
|
# Adjust train_details
|
||||||
|
adjusted_train_details = strategy.prompter.adjust_train_details(
|
||||||
|
turn["train_detail"], token_offsets
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Original train_details: {turn['train_detail']}")
|
||||||
|
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
|
||||||
|
|
||||||
|
# Handle train_detail
|
||||||
|
token_offsets = strategy.prompter.get_offsets_for_train_detail(
|
||||||
|
text=turn["value"],
|
||||||
|
train_details=adjusted_train_details,
|
||||||
|
mask_untrainable=False,
|
||||||
|
)
|
||||||
|
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
|
||||||
|
text=turn["value"],
|
||||||
|
train_details=adjusted_train_details,
|
||||||
|
mask_untrainable=True,
|
||||||
|
)
|
||||||
|
LOG.debug(f"Token offsets: {token_offsets_masked}")
|
||||||
|
|
||||||
|
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
|
||||||
|
for i, offset in enumerate(token_offsets_masked):
|
||||||
|
if offset != IGNORE_TOKEN_ID:
|
||||||
|
expected_labels[i] = turn_tokens[i]
|
||||||
|
actual_labels = labels[
|
||||||
|
start_idx : start_idx + len(token_offsets_masked)
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
actual_labels == expected_labels
|
||||||
|
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||||
|
|
||||||
|
for detail in adjusted_train_details:
|
||||||
|
# Find the token indices that correspond to the character offsets
|
||||||
|
detail_start = start_idx + next(
|
||||||
|
i
|
||||||
|
for i, offset in enumerate(token_offsets)
|
||||||
|
if offset >= detail["begin_offset"]
|
||||||
|
)
|
||||||
|
detail_end = start_idx + next(
|
||||||
|
(
|
||||||
|
i
|
||||||
|
for i, offset in enumerate(token_offsets)
|
||||||
|
if offset > detail["end_offset"]
|
||||||
|
),
|
||||||
|
len(token_offsets),
|
||||||
|
)
|
||||||
|
|
||||||
|
detail_text = turn["value"][
|
||||||
|
detail["begin_offset"] : detail["end_offset"] + 1
|
||||||
|
]
|
||||||
|
detail_labels = labels[detail_start:detail_end]
|
||||||
|
detail_input_ids = input_ids[detail_start:detail_end]
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
|
||||||
|
)
|
||||||
|
LOG.debug(f"Detail input_ids: {detail_input_ids}")
|
||||||
|
LOG.debug(f"Detail labels: {detail_labels}")
|
||||||
|
LOG.debug(
|
||||||
|
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
|
||||||
|
)
|
||||||
|
LOG.debug(
|
||||||
|
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if detail["train"]:
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID for label in detail_labels
|
||||||
|
), (
|
||||||
|
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
|
||||||
|
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||||
|
f"InputIDs: {detail_input_ids}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID for label in detail_labels
|
||||||
|
), (
|
||||||
|
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
|
||||||
|
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||||
|
f"InputIDs: {detail_input_ids}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
should_train = turn.get("train", False)
|
||||||
|
turn_labels = labels[start_idx:end_idx]
|
||||||
|
|
||||||
|
LOG.debug(f"Should train: {should_train}")
|
||||||
|
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
|
||||||
|
LOG.debug(f"Turn labels: {turn_labels}")
|
||||||
|
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
|
||||||
|
LOG.debug(
|
||||||
|
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_train:
|
||||||
|
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected all labels for '{turn['value']}' to be set\n"
|
||||||
|
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||||
|
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
|
||||||
|
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||||
|
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
|
||||||
|
f"start_idx: {start_idx}, end_idx: {end_idx}, "
|
||||||
|
f"labels: {labels[start_idx:end_idx]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Final labels: {labels}")
|
||||||
|
LOG.debug(f"Final input_ids: {input_ids}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user