From 55cc214c767741e83ee7b346e5e13e6c03b7b9fa Mon Sep 17 00:00:00 2001 From: Adam Brusselback Date: Sun, 28 Jul 2024 21:48:57 -0400 Subject: [PATCH] Add flexible configuration options for `chat_template` dataset training (#1756) * Add flexible configuration options for chat dataset training - Introduce roles_to_train parameter to set training labels by role - Add train_on_eos option to configure training on end-of-sequence tokens - Implement per-message training configuration in dataset - Allow fine-grained control over training specific portions of messages - Add message_field_training and message_field_training_detail settings - Implement mapping between dataset character offsets and tokenized prompt - Enhance test suite to cover new functionality * Fix missing field inits, things weren't working from yaml. * Add flexible configuration options for chat dataset training - Introduce roles_to_train parameter to set training labels by role - Add train_on_eos option to configure training on end-of-sequence tokens - Implement per-message training configuration in dataset - Allow fine-grained control over training specific portions of messages - Add message_field_training and message_field_training_detail settings - Implement mapping between dataset character offsets and tokenized prompt - Enhance test suite to cover new functionality * Fix missing field inits, things weren't working from yaml. * chore: lint * Revert test repo back to NousResearch after opening PR to fix the tokenizer_config.json. --------- Co-authored-by: Wing Lian --- .../prompt_strategies/chat_template.py | 313 ++++++- .../config/models/input/v0_4_1/__init__.py | 4 + .../prompt_strategies/test_chat_templates.py | 884 +++++++++++++++++- tests/prompt_strategies/test_sharegpt.py | 2 + 4 files changed, 1111 insertions(+), 92 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 8c7a8dd4f..f9fa71f21 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -6,14 +6,16 @@ import logging from typing import Any, Dict, List, Optional from axolotl.prompt_tokenizers import PromptTokenizingStrategy -from axolotl.prompters import Prompter +from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import chat_templates +# Configure the logger +logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger("axolotl") class ChatTemplatePrompter(Prompter): - """prompter for HF chat templates""" + """Prompter for HF chat templates""" def __init__( self, @@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter): max_length=2048, message_field_role: str = "from", message_field_content: str = "value", + message_field_training: str = "train", + message_field_training_detail: str = "train_detail", roles: Optional[Dict[str, List[str]]] = None, drop_system_message: bool = False, ): @@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter): } self.message_field_role = message_field_role self.message_field_content = message_field_content + self.message_field_training = message_field_training + self.message_field_training_detail = message_field_training_detail self.tokenizer = tokenizer self.chat_template = chat_template self.max_length = max_length @@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter): { "role": self.roles[t[self.message_field_role]], "content": t[self.message_field_content], + "training": t.get(self.message_field_training, None), } for t in conversation ] @@ -62,6 +69,108 @@ class ChatTemplatePrompter(Prompter): chat_template=self.chat_template, ) + def get_offsets_for_train_detail( + self, text: str, train_details: List[Dict], mask_untrainable: bool = True + ) -> List[int]: + tokenized_output = self.tokenizer( + text, return_offsets_mapping=True, add_special_tokens=False + ) + tokens = tokenized_output.tokens() + token_offsets = tokenized_output["offset_mapping"] + + LOG.debug(f"Tokenizing text: {text}") + LOG.debug(f"Tokens: {tokens}") + # Adjust the end offsets. For some reason by default they are set to the same value as the start offsets. + for i in range(len(token_offsets) - 1): + token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1) + # Ensure the last token's end offset is set correctly + token_offsets[-1] = (token_offsets[-1][0], len(text) - 1) + LOG.debug(f"Token offsets: {token_offsets}") + + # Initialize all offsets as IGNORE_TOKEN_ID (not trained) + result = [IGNORE_TOKEN_ID] * len(token_offsets) + + # Adjust train_details to align with token boundaries + adjusted_train_details = self.adjust_train_details(train_details, token_offsets) + + for idx, (start, end) in enumerate(token_offsets): + for detail in adjusted_train_details: + # Check if the token is completely within the detail's range + if start >= detail["begin_offset"] and end <= detail["end_offset"]: + if detail["train"] or not mask_untrainable: + result[idx] = start + LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training") + else: + LOG.debug( + f"Token {idx} ({tokens[idx]}) marked as non-trainable" + ) + elif start < detail["end_offset"] and end > detail["begin_offset"]: + # Token partially overlaps with detail, always mark as non-trainable + LOG.debug( + f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable" + ) + + LOG.debug(f"Final result: {result}") + return result + + def adjust_train_details( + self, train_details: List[Dict], token_offsets: List[tuple] + ) -> List[Dict]: + adjusted_details = [] + for detail in train_details: + begin_offset = detail["begin_offset"] + end_offset = detail["end_offset"] + + # Find the first token that starts after or at the begin_offset + begin_token = next( + ( + i + for i, (t_start, t_end) in enumerate(token_offsets) + if t_start >= begin_offset + ), + len(token_offsets), + ) + if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset: + begin_token -= 1 + + # Find the last token that ends before or at the end_offset + end_token = next( + ( + i + for i in range(len(token_offsets) - 1, -1, -1) + if token_offsets[i][1] <= end_offset + ), + -1, + ) + if ( + end_token < len(token_offsets) - 1 + and token_offsets[end_token + 1][0] < end_offset + ): + end_token += 1 + + if begin_token <= end_token: + adjusted_begin = token_offsets[begin_token][0] + adjusted_end = token_offsets[end_token][1] + + if adjusted_begin != begin_offset or adjusted_end != end_offset: + LOG.warning( + f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})" + ) + + adjusted_details.append( + { + "begin_offset": adjusted_begin, + "end_offset": adjusted_end, + "train": detail["train"], + } + ) + else: + LOG.warning( + f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail." + ) + + return adjusted_details + class ChatTemplateStrategy(PromptTokenizingStrategy): """ @@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): _messages = "conversations" + def __init__( + self, + prompter, + tokenizer, + train_on_inputs, + sequence_len, + roles_to_train=None, + train_on_eos="last", + ): + super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) + self.roles_to_train = roles_to_train if roles_to_train is not None else [] + self.train_on_eos = train_on_eos + @property def messages(self): return self._messages @@ -79,62 +201,169 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): self._messages = messages def tokenize_prompt(self, prompt): - turns = self.get_conversation_thread(prompt) - prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True) + turns = prompt[self.messages] input_ids = self.prompter.build_prompt(turns) + labels = [IGNORE_TOKEN_ID] * len(input_ids) - if not self.train_on_inputs: - user_prompt_len = len(prompt_ids) - labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] - else: - labels = input_ids + last_eos_idx = -1 + for index, turn in enumerate(turns): + role = turn.get(self.prompter.message_field_role) + content = turn.get(self.prompter.message_field_content) + train_turn = turn.get(self.prompter.message_field_training) + train_detail = turn.get(self.prompter.message_field_training_detail) - tokenized_prompt = { + LOG.debug( + f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}" + ) + + should_train = ( + train_turn + if train_turn is not None + else bool(train_detail is not None) + if train_detail is not None + else self.train_on_inputs or role in self.roles_to_train + ) + + LOG.debug(f"Should train: {should_train}") + + turn_start_idx, turn_end_idx = self.find_turn( + conversation_ids=input_ids, turn=index, turn_content=turn + ) + + LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}") + + if should_train and turn_start_idx != -1 and turn_end_idx != -1: + if train_detail: + token_offsets = self.prompter.get_offsets_for_train_detail( + content, train_detail + ) + LOG.debug(f"Token offsets: {token_offsets}") + for i, offset in enumerate(token_offsets): + if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len( + input_ids + ): + labels[turn_start_idx + i] = input_ids[turn_start_idx + i] + LOG.debug( + f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}" + ) + else: + labels[turn_start_idx:turn_end_idx] = input_ids[ + turn_start_idx:turn_end_idx + ] + LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}") + + LOG.debug(f"Labels after processing turn {index}: {labels}") + + # Handle EOS token + eos_idx = self.find_eos_token(input_ids, turn_end_idx) + if eos_idx == turn_end_idx: + last_eos_idx = eos_idx + if self.train_on_eos == "all" or ( + self.train_on_eos == "turn" and should_train + ): + labels[eos_idx] = input_ids[eos_idx] + LOG.debug(f"EOS token set for training at index {eos_idx}") + else: + LOG.debug( + f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}" + ) + + # Handle 'last' option for train_on_eos + if self.train_on_eos == "last" and last_eos_idx != -1: + labels[last_eos_idx] = input_ids[last_eos_idx] + LOG.debug(f"Last EOS token set for training at index {last_eos_idx}") + + LOG.debug(f"Final labels: {labels}") + + return { "input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids), } - return tokenized_prompt + def find_eos_token(self, input_ids, start_idx): + eos_token_id = self.tokenizer.eos_token_id + for i in range(start_idx, len(input_ids)): + if input_ids[i] == eos_token_id: + return i + return -1 + + def find_turn(self, conversation_ids, turn, turn_content): + """ + Locate the starting and ending indices of the specified turn in a conversation. + + Args: + conversation_ids (list[int]): Token IDs representing the conversation. + turn (int): The turn number to locate (based on EOS tokens). + turn_content (str): String containing the content of the turn. + + Returns: + tuple: (start_idx, end_idx) indices of the start and end of the turn content. + Returns (-1, -1) if the turn content is not found. + """ + content = turn_content.get(self.prompter.message_field_content, "") + content_ids = self.tokenizer.encode(content, add_special_tokens=False) + + eos_token_id = self.tokenizer.eos_token_id + eos_count = 0 + start_search_idx = 0 + + # Locate the starting index after the specified number of EOS tokens + for i, token_id in enumerate(conversation_ids): + if token_id == eos_token_id: + eos_count += 1 + if eos_count == turn: + start_search_idx = ( + i + 1 + ) # Start searching after the specified turn's EOS token + break + + # Find the start index of the content within the conversation + start_idx = -1 + for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1): + if conversation_ids[i : i + len(content_ids)] == content_ids: + start_idx = i + break + + if start_idx != -1: + end_idx = start_idx + len(content_ids) + else: + end_idx = -1 + + return start_idx, end_idx def get_conversation_thread(self, prompt): return prompt[self.messages] def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): - chat_template = ( - ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml" - ) - message_field_role = ( - ds_cfg["message_field_role"] - if ds_cfg and "message_field_role" in ds_cfg - else "from" - ) - message_field_content = ( - ds_cfg["message_field_content"] - if ds_cfg and "message_field_content" in ds_cfg - else "value" - ) - roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None - drop_system_message = ( - ds_cfg["drop_system_message"] - if ds_cfg and "drop_system_message" in ds_cfg - else False - ) + ds_cfg = ds_cfg or {} + + prompter_params = { + "tokenizer": tokenizer, + "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), + "message_field_role": ds_cfg.get("message_field_role", "from"), + "message_field_content": ds_cfg.get("message_field_content", "value"), + "message_field_training": ds_cfg.get("message_field_training", "training"), + "message_field_training_detail": ds_cfg.get( + "message_field_training_detail", "train_detail" + ), + "roles": ds_cfg.get("roles"), + "drop_system_message": ds_cfg.get("drop_system_message", False), + } + + strategy_params = { + "train_on_inputs": cfg.train_on_inputs, + "sequence_len": cfg.sequence_len, + "roles_to_train": ds_cfg.get("roles_to_train"), + "train_on_eos": ds_cfg.get("train_on_eos", "last"), + } strategy = ChatTemplateStrategy( - ChatTemplatePrompter( - tokenizer, - chat_templates(chat_template), - message_field_role=message_field_role, - message_field_content=message_field_content, - roles=roles, - drop_system_message=drop_system_message, - ), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, + ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params ) - if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"): + + if "field_messages" in ds_cfg and hasattr(strategy, "messages"): strategy.messages = ds_cfg["field_messages"] + return strategy diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 7397c7c73..e92c79485 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -116,6 +116,10 @@ class SFTDataset(BaseModel): field_messages: Optional[str] = None message_field_role: Optional[str] = None message_field_content: Optional[str] = None + message_field_training: Optional[str] = None + message_field_training_detail: Optional[str] = None + roles_to_train: Optional[List[str]] = None + train_on_eos: Optional[str] = None roles: Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 7b58a1236..e2fc0f6a5 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -2,6 +2,7 @@ tests for chat_template prompt strategy """ +import logging import unittest import pytest @@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import ( ChatTemplateStrategy, load, ) +from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.utils.chat_templates import chat_templates from axolotl.utils.dict import DictDefault +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + @pytest.fixture(name="assistant_dataset") def fixture_assistant_dataset(): - # pylint: disable=duplicate-code return Dataset.from_list( [ { "messages": [ - { - "role": "user", - "content": "hello", - }, - { - "role": "assistant", - "content": "hello", - }, - { - "role": "user", - "content": "goodbye", - }, - { - "role": "assistant", - "content": "goodbye", - }, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "goodbye"}, + {"role": "assistant", "content": "goodbye"}, ] } ] @@ -53,22 +45,28 @@ def fixture_sharegpt_dataset(): [ { "conversations": [ - { - "from": "human", - "value": "hello", - }, - { - "from": "gpt", - "value": "hello", - }, - { - "from": "human", - "value": "goodbye", - }, - { - "from": "gpt", - "value": "goodbye", - }, + {"from": "human", "value": "hello"}, + {"from": "gpt", "value": "hello"}, + {"from": "human", "value": "goodbye"}, + {"from": "gpt", "value": "goodbye"}, + ] + } + ] + ) + + +@pytest.fixture(name="basic_dataset") +def fixture_basic_dataset(): + # pylint: disable=duplicate-code + return Dataset.from_list( + [ + { + "conversations": [ + {"from": "system", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "assistant", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "assistant", "value": "I'm doing well, thank you!"}, ] } ] @@ -77,19 +75,611 @@ def fixture_sharegpt_dataset(): @pytest.fixture(name="llama3_tokenizer") def fixture_llama3_tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") - tokenizer.eos_token = "<|eot_id|>" + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") return tokenizer +class TestChatTemplateConfigurations: + """ + Test class for various configurations of ChatTemplateStrategy. + """ + + @staticmethod + def find_sublist(full_list, sub_list): + token_count = len(sub_list) + for index in range(len(full_list) - token_count + 1): + if full_list[index : index + token_count] == sub_list: + return index + return -1 + + def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Check the behavior of human inputs + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + labeled = all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ) + LOG.debug( + f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}" + ) + + LOG.debug("Full labels: %s", labels) + LOG.debug("Full input_ids: %s", input_ids) + + def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_inputs=False") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + # Verify that human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + LOG.debug( + f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}" + ) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(input_ids)] + ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}" + + def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with assistant only") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that only assistant responses are labeled + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing roles_to_train with all roles") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=True, + sequence_len=512, + roles_to_train=["human", "assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Verify that all responses are labeled (except for special tokens) + all_responses = [ + "Hello", + "Hi there!", + "How are you?", + "I'm doing well, thank you!", + ] + for response in all_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + LOG.debug( + f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}" + ) + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}" + + def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with empty roles_to_train") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + train_on_eos="none", # Add this line + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + + # Verify that no labels are set when roles_to_train is empty + LOG.debug("Full labels: %s", labels) + assert all( + label == IGNORE_TOKEN_ID for label in labels + ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" + + def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='all'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="all", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to be labeled" + + def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='turn'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="turn", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + assistant_responses = ["Hi there!", "I'm doing well, thank you!"] + + for response in assistant_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find '{response}' in input_ids" + + eos_idx = start_idx + len(response_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert eos_idx < len( + input_ids + ), f"Could not find EOS token after '{response}'" + assert ( + labels[eos_idx] != IGNORE_TOKEN_ID + ), f"Expected EOS token after assistant response '{response}' to be labeled" + + # Check that EOS tokens after human inputs are not labeled + human_inputs = ["Hello", "How are you?"] + for input_text in human_inputs: + input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, input_ids) + assert start_idx != -1, f"Could not find '{input_text}' in input_ids" + + eos_idx = start_idx + len(input_ids) + while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: + eos_idx += 1 + + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token after human input '{input_text}' to not be labeled" + + def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='last'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="last", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + last_eos_idx = eos_indices[-1] + + # Check that only the last EOS token is labeled + for idx in eos_indices[:-1]: + assert ( + labels[idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {idx} to not be labeled" + assert ( + labels[last_eos_idx] != IGNORE_TOKEN_ID + ), f"Expected last EOS token at index {last_eos_idx} to be labeled" + + def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with train_on_eos='none'") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + train_on_eos="none", + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + eos_token_id = llama3_tokenizer.eos_token_id + eos_indices = [ + i for i, token_id in enumerate(input_ids) if token_id == eos_token_id + ] + + assert len(eos_indices) > 0, "Expected at least one EOS token in the input" + for eos_idx in eos_indices: + assert ( + labels[eos_idx] == IGNORE_TOKEN_ID + ), f"Expected EOS token at index {eos_idx} to not be labeled" + + def test_drop_system_message(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing with drop_system_message=True") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + input_ids = res["input_ids"] + + # Check if system message is not present in input_ids + system_message = "You are an AI assistant." + system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False) + assert ( + self.find_sublist(input_ids, system_ids) == -1 + ), "Expected system message to be dropped" + + def test_custom_roles(self, llama3_tokenizer): + LOG.info("Testing with custom roles mapping") + custom_roles = { + "user": ["human", "user"], + "assistant": ["ai", "assistant"], + "system": ["context"], + } + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["ai"], + ) + + # Create a new dataset with modified role names + modified_conversations = [ + {"from": "context", "value": "You are an AI assistant."}, + {"from": "human", "value": "Hello"}, + {"from": "ai", "value": "Hi there!"}, + {"from": "human", "value": "How are you?"}, + {"from": "ai", "value": "I'm doing well, thank you!"}, + ] + + modified_dataset = Dataset.from_dict( + {"conversations": [modified_conversations]} + ) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Check if AI responses are labeled correctly + ai_responses = ["Hi there!", "I'm doing well, thank you!"] + for response in ai_responses: + response_ids = llama3_tokenizer.encode(response, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, response_ids) + assert start_idx != -1, f"Could not find response '{response}' in input_ids" + assert all( + label != IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(response_ids)] + ), f"Expected labels for AI response '{response}' to be set" + + # Check if human messages are not labeled + human_messages = ["Hello", "How are you?"] + for message in human_messages: + message_ids = llama3_tokenizer.encode(message, add_special_tokens=False) + start_idx = self.find_sublist(input_ids, message_ids) + assert start_idx != -1, f"Could not find message '{message}' in input_ids" + assert all( + label == IGNORE_TOKEN_ID + for label in labels[start_idx : start_idx + len(message_ids)] + ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID" + + def test_message_field_training(self, llama3_tokenizer): + LOG.info("Testing with message_field_training") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, + chat_templates("llama3"), + message_field_training="train", + message_field_training_detail="train_detail", + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=[], + ) + + # Create a new dataset with the train and train_detail fields + modified_conversation = [ + {"from": "system", "value": "You are an AI assistant.", "train": False}, + {"from": "human", "value": "Hello", "train": False}, + {"from": "assistant", "value": "Hello", "train": True}, + {"from": "human", "value": "How are you?", "train": True}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": False}, + {"begin_offset": 9, "end_offset": 18, "train": True}, + {"begin_offset": 19, "end_offset": 30, "train": False}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": False, + }, + {"from": "assistant", "value": "Hi there!", "train": True}, + ] + + modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]}) + + res = strategy.tokenize_prompt(modified_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Function to find all occurrences of a sublist + def find_all_sublists(full_list, sub_list): + indices = [] + for index in range(len(full_list) - len(sub_list) + 1): + if full_list[index : index + len(sub_list)] == sub_list: + indices.append(index) + return indices + + # Keep track of which occurrences we've processed + processed_occurrences = {} + # Check if messages are labeled correctly based on train or train_detail + for i, turn in enumerate(modified_conversation): + turn_tokens = llama3_tokenizer.encode( + turn["value"], add_special_tokens=False + ) + occurrences = find_all_sublists(input_ids, turn_tokens) + turn_key = turn["value"] + if turn_key not in processed_occurrences: + processed_occurrences[turn_key] = 0 + current_occurrence = processed_occurrences[turn_key] + + if current_occurrence >= len(occurrences): + assert ( + False + ), f"Not enough occurrences found for message: {turn['value']}" + + start_idx = occurrences[current_occurrence] + processed_occurrences[turn_key] += 1 + end_idx = start_idx + len(turn_tokens) + + LOG.debug( + f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}" + ) + + if "train_detail" in turn: + # Get token offsets + tokenized_output = llama3_tokenizer( + turn["value"], return_offsets_mapping=True, add_special_tokens=False + ) + token_offsets = tokenized_output["offset_mapping"] + + # Adjust token offsets as done in the implementation + for i in range(len(token_offsets) - 1): + token_offsets[i] = ( + token_offsets[i][0], + token_offsets[i + 1][0] - 1, + ) + token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1) + + # Adjust train_details + adjusted_train_details = strategy.prompter.adjust_train_details( + turn["train_detail"], token_offsets + ) + + LOG.debug(f"Original train_details: {turn['train_detail']}") + LOG.debug(f"Adjusted train_details: {adjusted_train_details}") + + # Handle train_detail + token_offsets = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=False, + ) + token_offsets_masked = strategy.prompter.get_offsets_for_train_detail( + text=turn["value"], + train_details=adjusted_train_details, + mask_untrainable=True, + ) + LOG.debug(f"Token offsets: {token_offsets_masked}") + + expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens) + for i, offset in enumerate(token_offsets_masked): + if offset != IGNORE_TOKEN_ID: + expected_labels[i] = turn_tokens[i] + actual_labels = labels[ + start_idx : start_idx + len(token_offsets_masked) + ] + assert ( + actual_labels == expected_labels + ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" + + for detail in adjusted_train_details: + # Find the token indices that correspond to the character offsets + detail_start = start_idx + next( + i + for i, offset in enumerate(token_offsets) + if offset >= detail["begin_offset"] + ) + detail_end = start_idx + next( + ( + i + for i, offset in enumerate(token_offsets) + if offset > detail["end_offset"] + ), + len(token_offsets), + ) + + detail_text = turn["value"][ + detail["begin_offset"] : detail["end_offset"] + 1 + ] + detail_labels = labels[detail_start:detail_end] + detail_input_ids = input_ids[detail_start:detail_end] + + LOG.debug( + f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}" + ) + LOG.debug(f"Detail input_ids: {detail_input_ids}") + LOG.debug(f"Detail labels: {detail_labels}") + LOG.debug( + f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}" + ) + LOG.debug( + f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}" + ) + + if detail["train"]: + assert all( + label != IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + assert all( + label == IGNORE_TOKEN_ID for label in detail_labels + ), ( + f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. " + f"Labels({detail_start}:{detail_end}): {detail_labels}, " + f"InputIDs: {detail_input_ids}, " + f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'" + ) + else: + should_train = turn.get("train", False) + turn_labels = labels[start_idx:end_idx] + + LOG.debug(f"Should train: {should_train}") + LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}") + LOG.debug(f"Turn labels: {turn_labels}") + LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}") + LOG.debug( + f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}" + ) + + if should_train: + assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be set\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + else: + assert all(label == IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n" + f"Labels({start_idx}:{end_idx}): {turn_labels}, " + f"InputIDs: {input_ids[start_idx:end_idx]}, " + f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'" + ) + + LOG.debug( + f"Processed turn: {turn['from']}, content: '{turn['value']}', " + f"start_idx: {start_idx}, end_idx: {end_idx}, " + f"labels: {labels[start_idx:end_idx]}" + ) + + LOG.debug(f"Final labels: {labels}") + LOG.debug(f"Final input_ids: {input_ids}") + + class TestAssistantChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. """ def test_llama3_load(self, llama3_tokenizer, assistant_dataset): - # pylint: disable=duplicate-code + LOG.info("Loading llama-3 tokenizer with assistant dataset") strategy = load( llama3_tokenizer, DictDefault( @@ -115,21 +705,26 @@ class TestAssistantChatTemplateLlama3: res = strategy.tokenize_prompt(assistant_dataset[0]) input_ids = res["input_ids"] # fmt: off - assert input_ids == [ + expected_input_ids = [ 128000, # bos 128006, 882, 128007, # user header 271, 15339, 128009, # user prompt eot 128006, 78191, 128007, # assistant header - 271, 15339, 128009, # assistant response eot + 271, 15339, 128009, # assistant response eot 128006, 882, 128007, 271, 19045, 29474, 128009, 128006, 78191, 128007, 271, 19045, 29474, 128009, ] # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" def test_llama3(self, llama3_tokenizer, assistant_dataset): - # pylint: disable=duplicate-code + LOG.info("Testing llama-3 with assistant dataset") strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, @@ -142,15 +737,16 @@ class TestAssistantChatTemplateLlama3: "system": ["system"], }, ), - llama3_tokenizer, - False, - 512, + tokenizer=llama3_tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], ) strategy.messages = "messages" res = strategy.tokenize_prompt(assistant_dataset[0]) input_ids = res["input_ids"] # fmt: off - assert input_ids == [ + expected_input_ids = [ 128000, # bos 128006, 882, 128007, # user header 271, 15339, 128009, # user prompt eot @@ -162,6 +758,64 @@ class TestAssistantChatTemplateLlama3: 271, 19045, 29474, 128009, ] # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + + def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): + LOG.info("Testing llama-3 with assistant dataset including training data") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + llama3_tokenizer, + chat_templates("llama3"), + message_field_role="role", + message_field_content="content", + message_field_training="training", + roles={ + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + ), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["assistant"], + ) + strategy.messages = "messages" + prompt_tokens = strategy.prompter.build_prompt( + assistant_dataset[0]["messages"], False + ) + prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False) + LOG.debug(f"Generated prompt: {prompt}") + res = strategy.tokenize_prompt(assistant_dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + # fmt: off + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID, + ] + # fmt: on + + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + assert labels == expected_labels, ( + f"Labels mismatch:\n" + f"Expected: {expected_labels}\n" + f"Actual: {labels}\n" + f"Input IDs: {input_ids}\n" + ) class TestSharegptChatTemplateLlama3: @@ -169,30 +823,160 @@ class TestSharegptChatTemplateLlama3: Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy. """ - def test_llama3(self, llama3_tokenizer, sharegpt_dataset): - # pylint: disable=duplicate-code + def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): + LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts") strategy = ChatTemplateStrategy( ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), - llama3_tokenizer, - False, - 512, + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["gpt"], ) res = strategy.tokenize_prompt(sharegpt_dataset[0]) input_ids = res["input_ids"] + labels = res["labels"] # fmt: off - assert input_ids == [ + expected_input_ids = [ 128000, # bos 128006, 882, 128007, # user header 271, 15339, 128009, # user prompt eot 128006, 78191, 128007, # assistant header - 271, 15339, 128009, # assistant response eot + 271, 15339, 128009, # assistant response eot 128006, 882, 128007, 271, 19045, 29474, 128009, 128006, 78191, 128007, 271, 19045, 29474, 128009, ] + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID, + ] # fmt: on + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert ( + labels == expected_labels + ), f"Labels mismatch: {labels} != {expected_labels}" + + def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): + LOG.info("Testing ShareGPT style datasets with llama-3 human prompts") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["human"], + ) + res = strategy.tokenize_prompt(sharegpt_dataset[0]) + input_ids = res["input_ids"] + labels = res["labels"] + # fmt: off + expected_input_ids = [ + 128000, # bos + 128006, 882, 128007, # user header + 271, 15339, 128009, # user prompt eot + 128006, 78191, 128007, # assistant header + 271, 15339, 128009, # assistant response eot + 128006, 882, 128007, + 271, 19045, 29474, 128009, + 128006, 78191, 128007, + 271, 19045, 29474, 128009, + ] + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + ] + # fmt: on + + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert ( + labels == expected_labels + ), f"Labels mismatch: {labels} != {expected_labels}" + + def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): + LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts") + strategy = ChatTemplateStrategy( + ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + tokenizer=llama3_tokenizer, + train_on_inputs=False, + train_on_eos="none", + sequence_len=512, + roles_to_train=["system", "human"], + ) + res = strategy.tokenize_prompt(basic_dataset[0]) + input_ids = res["input_ids"] + labels = res["labels"] + # fmt: off + expected_input_ids = [ + 128000, # bos + 128006, 9125, 128007, + 271, 2675, 527, 459, 15592, 18328, 13, 128009, + 128006, 882, 128007, # user header + 271, 9906, 128009, # user prompt eot + 128006, 78191, 128007, # assistant header + 271, 13347, 1070, 0, 128009, # assistant response eot + 128006, 882, 128007, + 271, 4438, 527, 499, 30, 128009, + 128006, 78191, 128007, + 271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009, + ] + expected_labels = [ + IGNORE_TOKEN_ID, # bos + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header + IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header + IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, + ] + # fmt: on + + LOG.debug(f"Expected input_ids: {expected_input_ids}") + LOG.debug(f"Actual input_ids: {input_ids}") + LOG.debug(f"Expected labels: {expected_labels}") + LOG.debug(f"Actual labels: {labels}") + + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert ( + labels == expected_labels + ), f"Labels mismatch: {labels} != {expected_labels}" + if __name__ == "__main__": unittest.main() diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index aba53cd5f..e7a73a0de 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -192,6 +192,7 @@ class TestSharegptLlama3: input_ids = dataset_wrapper[0]["input_ids"] # fmt: off + # pylint: disable=duplicate-code assert input_ids == [ 128000, # bos 128006, 9125, 128007, # system header @@ -228,6 +229,7 @@ class TestSharegptLlama3: input_ids = dataset_wrapper[0]["input_ids"] # fmt: off + # pylint: disable=duplicate-code assert input_ids == [ 128000, # bos 128006, 9125, 128007, # system header