diff --git a/docs/config.qmd b/docs/config.qmd
index cb39e1d54..8795fa4ab 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -187,7 +187,7 @@ datasets:
     # IMPORTANT: The following fields determine which parts of the conversation to train on.
     # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
     # See examples at `docs/dataset-formats/conversation.qmd`
-    # Note: If the below 4 fields are set to empty, defaults to training only on the last message.
+    # Note: If the 5 fields below are empty, defaults to training only on the last message.

     # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
     roles_to_train: ["assistant"] # default
@@ -196,7 +196,13 @@ datasets:
     # - turn (default): train on the EOS token at the end of each trainable turn
     # - last: train on the last EOS token in the conversation
     # TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
-    train_on_eos: last
+    train_on_eos: turn
+    # Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
+    # - all: train on all EOT tokens
+    # - turn: train on the EOT token at the end of each trainable turn
+    # - last: train on the last EOT token in the conversation
+    # If not specified, defaults to the value of train_on_eos for backward compatibility.
+    train_on_eot:

     # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
     message_field_training: training
     # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
@@ -279,8 +285,17 @@ process_reward_model:
 chat_template: tokenizer_default
 # custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
 chat_template_jinja: null
-# Changes the default system message. Currently only supports chatml.
-default_system_message: You are a helpful assistant. Please give a long and detailed answer.
+# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
+# These tokens mark the boundaries between conversation turns.
+# For example: ["</s>", "[/INST]", "[/SYSTEM_PROMPT]"]
+# If not specified, defaults to just the model's eos_token.
+# This is useful for templates that use multiple delimiter tokens.
+eot_tokens:
+  # - "</s>"
+  # - "[/INST]"
+  # - "[/SYSTEM_PROMPT]"
+# Changes the default system message.
+default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.

 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
@@ -665,8 +680,10 @@ special_tokens:
   # unk_token: "<unk>"
   # pad_token: "[PAD]"

-# Add extra tokens.
+# Optional[list[str]]. Add extra tokens to the tokenizer.
 tokens:
+  # - "<|startoftext|>"
+  # - "<|endoftext|>"

 # Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
 # Only works for tokens that are not part of the base vocab (aka are added_tokens).
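+# For example, a minimal sketch of such a mapping (the ids below are hypothetical;
+# pick added_token ids that are actually unused in your tokenizer):
+# added_tokens_overrides:
+#   32008: "[/INST]"
+#   32009: "[/SYSTEM_PROMPT]"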
diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd
index 81c902afd..ee9f7391b 100644
--- a/docs/dataset-formats/conversation.qmd
+++ b/docs/dataset-formats/conversation.qmd
@@ -4,18 +4,6 @@ description: Conversation format for supervised fine-tuning.
 order: 3
 ---

-## sharegpt
-
-::: {.callout-important}
-ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
-:::
-
-## pygmalion
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"role": "...", "value": "..."}]}
-```
-
 ## chat_template

 Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
@@ -64,7 +52,7 @@ We recommend checking the below examples for other usecases.

 ### Examples

-1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
+1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training only on the last message.

 ```yaml
 datasets:
@@ -109,10 +97,55 @@ datasets:
 ```

 ::: {.callout-important}
-Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
+Please make sure that your `tokenizer.eos_token` is the same as the EOS (End-of-Sequence) token in the template. Otherwise, set `eos_token` under `special_tokens`.
 :::

-5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
+5. If your template uses an EOT (End-of-Turn) token that differs from the EOS token, or uses multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens` config. The handling of EOT tokens follows `train_on_eot`, which defaults to the value of `train_on_eos` (itself defaulting to `turn`).
+
+```yaml
+eot_tokens:
+  - "[/INST]"
+  # - "[/SYSTEM_PROMPT]"
+
+datasets:
+  - path: ...
+    type: chat_template
+
+    # optional
+    train_on_eot: turn # default read from train_on_eos (which defaults to turn)
+```
+
+::: {.callout-tip}
+See the [config documentation](../config.qmd) for detailed explanations of the "turn", "last", and "all" options for training on tokens.
+:::
+
+::: {.callout-note}
+Each token in `eot_tokens` that appears in the `chat_template` must be encoded as a single token by the tokenizer. Otherwise, the tokenizer will split it and cause unexpected behavior.
+
+You can add those tokens as new tokens under `tokens:` or (recommended) override unused added_tokens via `added_tokens_overrides:`, as sketched below. See [config](../config.qmd) for more details.
+:::
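+
+For instance, a minimal sketch (the id below is hypothetical; look up added_token ids that are actually unused in your tokenizer):
+
+```yaml
+# either register the delimiter as a brand-new token...
+tokens:
+  - "[/SYSTEM_PROMPT]"
+
+# ...or (recommended) remap a reserved-but-unused added_token in place
+added_tokens_overrides:
+  32008: "[/SYSTEM_PROMPT]"
+```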
+
+6. Continuing from the previous example, if you want to train on the EOT token of every trainable turn but only on the last EOS token, set `train_on_eos: last`.
+
+```yaml
+eot_tokens:
+  - "[/INST]"
+  # ...
+
+datasets:
+  - path: ...
+    type: chat_template
+
+    train_on_eos: last
+    train_on_eot: turn
+```
+
+::: {.callout-tip}
+If the EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, you can generally leave both at their defaults and omit them.
+:::
+
+
+7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation

 For a data sample that looks like:
@@ -162,3 +195,15 @@ datasets:
 ::: {.callout-tip}
 It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
 :::
+
+## sharegpt
+
+::: {.callout-important}
+ShareGPT is deprecated! Please see the [chat_template](#chat_template) section.
+:::
+
+## pygmalion
+
+```{.json filename="data.jsonl"}
+{"conversations": [{"role": "...", "value": "..."}]}
+```
diff --git a/docs/faq.qmd b/docs/faq.qmd
index 664359cb8..f586099e7 100644
--- a/docs/faq.qmd
+++ b/docs/faq.qmd
@@ -73,10 +73,40 @@ description: Frequently asked questions

 > A: This is likely an empty turn.

-**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
+**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**

-> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
+> A: There can be two reasons:
+
+> 1. There is a mismatch between `tokenizer.eos_token` and the EOS token in the template. Please set `eos_token` under `special_tokens` to the same EOS token as in the template.
+
+> 2. The EOS token is not in the template. Please check that your template is correct. For example, the `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.

 **Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**

 > A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
+
+**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
+
+> A: There can be two reasons:
+
+> 1. The EOT token differs from the EOS token and was not specified under `eot_tokens`. Please set `eot_tokens` to the same EOT token(s) as in the template.
+
+> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples, as we recognize this as an edge case.
+
+**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
+
+> A: There could be an issue with the tokenizer or unicode encoding. Please raise an issue with the EOT token and tokenizer that cause the problem.
+
+**Q: `EOT token __ is encoded as multiple tokens.`**
+
+> A: The EOT token is encoded as multiple tokens, which can cause unexpected behavior. Please add it under `tokens:` or (recommended) override unused added_tokens via `added_tokens_overrides:`.
+
+**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
+
+> A: The EOS token is listed in `eot_tokens` while `train_on_eos` and `train_on_eot` differ, so one setting would override the other. Please make `train_on_eos` and `train_on_eot` match, or remove the EOS token from `eot_tokens`.
+
+**Q: If `eot_tokens` is not provided, what happens?**
+
+> A: The default behavior is the same as before: EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
+
+> Internally, `eot_tokens` defaults to `[tokenizer.eos_token]` and `train_on_eot` defaults to `train_on_eos` (which itself defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
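+
+> For illustration, a minimal sketch of the implied defaults (the EOS value below is an assumption; substitute your tokenizer's actual EOS token):
+
+> ```yaml
+> eot_tokens:              # implicit default: [tokenizer.eos_token]
+>   - "</s>"
+> datasets:
+>   - path: ...
+>     type: chat_template
+>     train_on_eos: turn   # default
+>     # train_on_eot:      # unset, so it follows train_on_eos
+> ```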
diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py
index 699728e9f..131570aea 100644
--- a/src/axolotl/integrations/kd/chat_template.py
+++ b/src/axolotl/integrations/kd/chat_template.py
@@ -35,6 +35,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         sequence_len,
         roles_to_train=None,
         train_on_eos=None,
+        train_on_eot=None,
+        eot_tokens=None,
         logprobs_field="logprobs",
         gen_temperature=1.0,
         kd_temperature=1.0,
@@ -50,6 +52,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
             sequence_len,
             roles_to_train=roles_to_train,
             train_on_eos=train_on_eos,
+            train_on_eot=train_on_eot,
+            eot_tokens=eot_tokens,
         )

     @property
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index d16eb34e1..076ddac1f 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -222,10 +222,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         self,
         prompter: "ChatTemplatePrompter",
         tokenizer,
-        train_on_inputs,
-        sequence_len,
-        roles_to_train=None,
-        train_on_eos=None,
+        train_on_inputs: bool,
+        sequence_len: int,
+        roles_to_train: Optional[List[str]] = None,
+        train_on_eos: Optional[str] = None,
+        train_on_eot: Optional[str] = None,
+        eot_tokens: Optional[List[str]] = None,
     ):
         super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
         self.prompter: ChatTemplatePrompter = prompter
@@ -238,12 +240,87 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         ]

         self.train_on_eos = train_on_eos
+        # Backward compatibility, load from train_on_eos
+        self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
+
+        # Default to eos_token if eot_tokens not provided
+        self.eot_tokens = (
+            eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
+        )
+
         self.images = "images"

         LOG.debug(
             f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
         )

+        self._validate_eot_and_eos_tokens()
+
+    def _validate_eot_and_eos_tokens(self):
+        """
+        - Validates that EOT tokens (or eos_token) are in the chat_template
+        - Checks if EOT tokens are encoded as multiple tokens in the tokenizer.
+        - Checks for potential conflicts between train_on_eos and train_on_eot.
+        """
+        if self.prompter.chat_template is None:
+            # Usually this should not happen
+            LOG.warning(
+                "No chat template provided, skipping EOT and EOS token validation"
+            )
+            return
+
+        # If the EOT token is the same as the EOS token, we need to check differently
+        if len(self.eot_tokens) == 1 and self.eot_tokens[0] == self.tokenizer.eos_token:
+            # Check if the eos_token is in the chat_template or as a variable `eos_token`
+            # Note: we check for `eos_token` in the string, but it could possibly not be a variable
+            if (
+                self.tokenizer.eos_token not in self.prompter.chat_template
+                and "eos_token" not in self.prompter.chat_template
+            ):
+                LOG.warning(
+                    f"EOS token '{self.tokenizer.eos_token}' not found in chat_template. Please check if your template/EOS token is correct."
+                )
+            return
+
+        # Create a new list to store tokens that should be kept
+        valid_eot_tokens = []
+        for token in self.eot_tokens:
+            # Check if EOT token is in the chat_template
+            if token not in self.prompter.chat_template:
+                LOG.warning(f"EOT token '{token}' not found in chat_template.")
+                # Don't add to the valid tokens list
+                continue
+
+            valid_eot_tokens.append(token)
+
+        # Replace the original list with the filtered one
+        self.eot_tokens = valid_eot_tokens
+
+        for token in self.eot_tokens:
+            # Token is in the template; check that the tokenizer encodes it as exactly one token
+            token_ids = self.tokenizer.encode(token, add_special_tokens=False)
+            if not token_ids:
+                raise ValueError(
+                    "EOT token encoding failed. Please check if the token is valid and can be encoded."
+                )
+            if len(token_ids) > 1:
+                raise ValueError(
+                    f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config "
+                    "or (recommended) override unused added_tokens via `added_tokens_overrides: `."
+                )
+
+        # If eos_token is in eot_tokens and train_on_eos/train_on_eot conflict, raise an error
+        if (
+            self.tokenizer.eos_token in self.eot_tokens
+            and self.train_on_eos != self.train_on_eot
+        ):
+            raise ValueError(
+                "Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot. "
+                f"train_on_eos: {self.train_on_eos}, train_on_eot: {self.train_on_eot}, "
+                f"eot_tokens: {self.eot_tokens}, "
+                f"eos_token: {self.tokenizer.eos_token}"
+            )
+
     @property
     def supports_batched(self) -> bool:
         # Let calling code know we can handle lists of examples
@@ -287,6 +364,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         if (
             not self.roles_to_train
             and not self.train_on_eos
+            and not self.train_on_eot
             and not self.prompter.message_field_training  # type: ignore
             and not self.prompter.message_field_training_detail  # type: ignore
         ):
@@ -322,6 +400,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         labels = [IGNORE_TOKEN_ID] * len(input_ids)

         last_eos_idx = -1
+        last_eot_idx = -1
         for index, turn in enumerate(turns):
             role = turn.get("role")
             content = turn.get("content")
@@ -370,24 +449,45 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

             LOG.debug(f"Labels after processing turn {index}: {labels}")

-            # Handle EOS token
-            eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
-            if abs(eos_idx - turn_end_idx) <= 3:  # Allow for some template padding
-                last_eos_idx = eos_idx
-                if self.train_on_eos == "all" or (
-                    self.train_on_eos == "turn" and should_train
-                ):
-                    labels[eos_idx] = input_ids[eos_idx]
-                    LOG.debug(f"EOS token set for training at index {eos_idx}")
-            else:
-                LOG.debug(
-                    f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
-                )
+            # Handle special tokens (EOT and EOS)
+            for token_type, find_func, train_option in [
+                ("EOT", self.find_first_eot_token, self.train_on_eot),
+                ("EOS", self.find_first_eos_token, self.train_on_eos),
+            ]:
+                token_idx = find_func(input_ids, start_idx=turn_end_idx)

-        # Handle 'last' option for train_on_eos
-        if self.train_on_eos == "last" and last_eos_idx != -1:
-            labels[last_eos_idx] = input_ids[last_eos_idx]
-            LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
+                if (
+                    token_idx != -1 and abs(token_idx - turn_end_idx) <= 3
+                ):  # Allow for some template padding
+                    # Update the last token index
+                    if token_type == "EOT":  # nosec B105
+                        last_eot_idx = token_idx
+                    else:
+                        last_eos_idx = token_idx
+
+                    # Set labels if needed for this turn
+                    if train_option == "all" or (
+                        train_option == "turn" and should_train
+                    ):
+                        labels[token_idx] = input_ids[token_idx]
+                        LOG.debug(
+                            f"{token_type} token set for training at index {token_idx}"
+                        )
+                else:
+                    LOG.debug(
+                        f"{token_type} token missing after turn {turn}. {token_type.lower()}_idx: {token_idx}, turn_end_idx: {turn_end_idx}"
+                    )
+
+        # Handle 'last' option for special tokens
+        for token_type, last_idx, train_option in [
+            ("EOT", last_eot_idx, self.train_on_eot),
+            ("EOS", last_eos_idx, self.train_on_eos),
+        ]:
+            if train_option == "last" and last_idx != -1:
+                labels[last_idx] = input_ids[last_idx]
+                LOG.debug(
+                    f"Last {token_type} token set for training at index {last_idx}"
+                )

         LOG.debug(f"Final labels: {labels}")

@@ -404,6 +504,25 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                 return i
         return -1

+    def find_first_eot_token(self, input_ids, start_idx):
+        """Find the first EOT token in the input_ids starting from start_idx."""
+        # Get token IDs for all EOT tokens
+        eot_token_ids = []
+        for token in self.eot_tokens:
+            token_ids = self.tokenizer.encode(token, add_special_tokens=False)
+            if len(token_ids) != 1:
+                raise ValueError(
+                    f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
+                )
+
+            eot_token_ids.append(token_ids[0])  # exactly one ID, guaranteed by the check above
+
+        # Search for any of the EOT token IDs
+        for i in range(start_idx, len(input_ids)):
+            if input_ids[i] in eot_token_ids:
+                return i
+        return -1
+
     def find_turn(self, turns: list[dict], turn_idx: int):
         """
         Locate the starting and ending indices of the specified turn in a conversation.
@@ -568,6 +687,8 @@ class StrategyLoader:
             "sequence_len": cfg.sequence_len,
             "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
             "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
+            "train_on_eot": ds_cfg.get("train_on_eot", None),
+            "eot_tokens": cfg.get("eot_tokens", None),  # loads from cfg, not ds_cfg
         }

     def __call__(
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 2e0a6027c..36c18fd3c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -309,6 +309,7 @@ class AxolotlInputConfig(
         | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
     ) | None = None
     chat_template_jinja: str | None = None
+    eot_tokens: list[str] | None = None
     default_system_message: str | None = None

     fix_untrained_tokens: int | list[int] | None = None
diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py
index ce55b871f..38a5b6c43 100644
--- a/tests/prompt_strategies/test_chat_templates_advanced.py
+++ b/tests/prompt_strategies/test_chat_templates_advanced.py
@@ -2,6 +2,8 @@
 tests for chat_template prompt strategy
 """

+# pylint: disable=too-many-lines
+
 import logging
 from copy import deepcopy

@@ -53,14 +55,6 @@ class TestChatTemplateConfigurations:
     Test class for various configurations of ChatTemplateStrategy.
     """

-    @staticmethod
-    def find_sublist(full_list, sub_list):
-        token_count = len(sub_list)
-        for index in range(len(full_list) - token_count + 1):
-            if full_list[index : index + token_count] == sub_list:
-                return index
-        return -1
-
     @staticmethod
     def setup_tokenizer(
         tokenizer_name,
         chat_template=None,
         chat_template_jinja=None,
         eos_token=None,
         request=None,
+        eot_token=None,
     ) -> tuple[PreTrainedTokenizer, str]:
         """
         Helper function to set up the tokenizer and chat template for the test.
@@ -88,6 +83,10 @@ class TestChatTemplateConfigurations:
             "CodeLlamaTokenizerFast",
         ):
             tokenizer.update_post_processor()
+
+        if eot_token:
+            tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
+
         return tokenizer, chat_template_jinja

     def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):
@@ -974,3 +973,311 @@ class TestChatTemplateConfigurations:
             raise ValueError(
                 f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
             )
+
+    def test_eot_tokens_conflict_with_eos_token(
+        self,
+        tokenizer,
+        chat_template,
+        chat_template_jinja,
+        eos_token,
+        basic_dataset,  # pylint: disable=unused-argument
+        request,
+    ):
+        """Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
+        LOG.info(
+            "Testing conflict between eot_tokens containing eos_token and train_on_eot/train_on_eos mismatch"
+        )
+
+        tokenizer, chat_template_jinja = self.setup_tokenizer(
+            tokenizer, chat_template, chat_template_jinja, eos_token, request
+        )
+
+        # Create a situation where eot_tokens contains eos_token
+        eot_tokens = [
+            tokenizer.eos_token,
+            "[/INST]",
+        ]  # Deliberately including eos_token
+
+        # Create conflicting train_on_eos and train_on_eot settings
+        with pytest.raises(
+            ValueError,
+            match=".*eos_token is in eot_tokens and train_on_eos != train_on_eot.*",
+        ):
+            ChatTemplateStrategy(
+                ChatTemplatePrompter(
+                    tokenizer,
+                    chat_template=get_chat_template(
+                        chat_template, jinja_template=chat_template_jinja
+                    ),
+                    message_property_mappings={"role": "from", "content": "value"},
+                    field_messages="conversations",
+                ),
+                tokenizer=tokenizer,
+                train_on_inputs=False,
+                sequence_len=512,
+                roles_to_train=["assistant"],
+                train_on_eos="none",  # Setting to none
+                train_on_eot="turn",  # Different from train_on_eos
+                eot_tokens=eot_tokens,
+            )
+
+    def test_eot_token_backward_compatibility(
+        self,
+        tokenizer,
+        chat_template,
+        chat_template_jinja,
+        eos_token,
+        basic_dataset,  # pylint: disable=unused-argument
+        request,
+    ):
+        """Test that eot_tokens inherits from eos_token when not specified"""
+        LOG.info("Testing backward compatibility: eot_tokens inherits from eos_token")
+
+        tokenizer, chat_template_jinja = self.setup_tokenizer(
+            tokenizer, chat_template, chat_template_jinja, eos_token, request
+        )
+
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                tokenizer,
+                chat_template=get_chat_template(
+                    chat_template, jinja_template=chat_template_jinja
+                ),
+                message_property_mappings={"role": "from", "content": "value"},
+                field_messages="conversations",
+            ),
+            tokenizer=tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+            roles_to_train=["assistant"],
+            train_on_eos="turn",  # Setting train_on_eos to "turn"
+        )
+
+        # In backward compatibility mode, eot_tokens should be derived from eos_token
+        assert strategy.eot_tokens == [
+            tokenizer.eos_token
+        ], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
+        assert (
+            strategy.train_on_eot == "turn"
+        ), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
+
+    def test_token_not_in_template(
+        self,
+        tokenizer,
+        chat_template,
+        chat_template_jinja,
+        eos_token,
+        basic_dataset,
+        request,
+    ):
+        """Test that tokenization still runs when tokens are not found in the template"""
+        LOG.info("Testing that tokenization runs even when tokens are not found in template")
+
+        tokenizer, chat_template_jinja = self.setup_tokenizer(
+            tokenizer, chat_template, chat_template_jinja, eos_token, request
+        )
+
+        # Create a non-existent token that definitely won't be in the template
+        non_existent_token = "[DEFINITELY_NOT_IN_TEMPLATE]"
+        tokenizer.add_special_tokens(
+            {"additional_special_tokens": [non_existent_token]}
+        )
+
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                tokenizer,
+                chat_template=get_chat_template(
+                    chat_template, jinja_template=chat_template_jinja
+                ),
+                message_property_mappings={"role": "from", "content": "value"},
+                field_messages="conversations",
+            ),
+            tokenizer=tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+            roles_to_train=["assistant"],
+            eot_tokens=[non_existent_token],
+        )
+
+        # Force template check by calling tokenize_prompt
+        strategy.tokenize_prompt(basic_dataset[0])
+
+        # We could also check that a warning was logged, but there are
+        # caplog conflicts when running with other tests
+        # assert any(
+        #     "not found in chat_template" in record.message for record in self._caplog.records
+        # ), "Expected warning about token not found in template was not logged"
+
+    def test_custom_eot_tokens(
+        self,
+        tokenizer,
+        chat_template,
+        chat_template_jinja,
+        eos_token,  # pylint: disable=unused-argument
+        basic_dataset,
+        request,
+    ):
+        """Test with custom EOT tokens to ensure proper masking and training"""
+        LOG.info("Testing with custom EOT tokens")
+
+        tokenizer, chat_template_jinja = self.setup_tokenizer(
+            tokenizer, chat_template, chat_template_jinja, None, request
+        )
+
+        # Add custom EOT tokens to the tokenizer
+        custom_eot = "[EOT]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [custom_eot]})
+
+        # Create a custom chat template that uses our EOT token
+        custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content'] }}{% elif message['role'] == 'user' %}User: {{ message['content'] }}{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}[EOT]{% endif %}{% endfor %}"""
+
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                tokenizer,
+                chat_template=custom_template,
+                message_property_mappings={"role": "from", "content": "value"},
+                field_messages="conversations",
+            ),
+            tokenizer=tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+            roles_to_train=["assistant"],
+            train_on_eot="turn",  # Train on the EOT token of each trainable turn
+            eot_tokens=[custom_eot],
+        )
+
+        res = strategy.tokenize_prompt(basic_dataset[0])
+        labels = res["labels"]
+        input_ids = res["input_ids"]
+
+        # Find indices of the EOT token
+        eot_token_id = tokenizer.convert_tokens_to_ids(custom_eot)
+        eot_indices = [
+            i for i, token_id in enumerate(input_ids) if token_id == eot_token_id
+        ]
+
+        assert len(eot_indices) > 0, "Expected at least one EOT token in the input"
+
+        # Verify labeling for EOT tokens based on role
+        turns = strategy.get_conversation_thread(basic_dataset[0])
+        assistant_turn_indices = []
+        non_assistant_turn_indices = []
+
+        for i, turn in enumerate(basic_dataset[0]["conversations"]):
+            start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+            if start_idx != -1 and end_idx != -1:  # If turn is found
+                if turn["from"] == "assistant":
+                    assistant_turn_indices.append((start_idx, end_idx))
+                else:
+                    non_assistant_turn_indices.append((start_idx, end_idx))
+
+        # Check EOT tokens after assistant turns are labeled
+        for eot_idx in eot_indices:
+            is_after_assistant = any(
+                start_idx <= eot_idx <= end_idx + 1  # +1 to include the EOT token
+                for start_idx, end_idx in assistant_turn_indices
+            )
+
+            if is_after_assistant:
+                assert (
+                    labels[eot_idx] != IGNORE_TOKEN_ID
+                ), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
+            else:
+                assert (
+                    labels[eot_idx] == IGNORE_TOKEN_ID
+                ), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
+
+    def test_multiple_train_on_eot_settings(
+        self,
+        tokenizer,
+        chat_template,
+        chat_template_jinja,
+        eos_token,
+        basic_dataset,
+        request,
+    ):
+        """Test different train_on_eot settings"""
+        LOG.info("Testing different train_on_eot settings")
+
+        tokenizer, chat_template_jinja = self.setup_tokenizer(
+            tokenizer, chat_template, chat_template_jinja, eos_token, request
+        )
+
+        # Create a list to test different train_on_eot settings
+        test_settings = [
+            ("none", lambda idx, is_assistant: False),  # Never train on EOT
+            ("all", lambda idx, is_assistant: True),  # Always train on EOT
+            (
+                "turn",
+                lambda idx, is_assistant: is_assistant,
+            ),  # Train on EOT after assistant turns
+            ("last", lambda idx, is_last: is_last),  # Only train on last EOT
+        ]
+
+        for setting, expected_train_func in test_settings:
+            LOG.info(f"Testing train_on_eot='{setting}'")
+
+            strategy = ChatTemplateStrategy(
+                ChatTemplatePrompter(
+                    tokenizer,
+                    chat_template=get_chat_template(
+                        chat_template, jinja_template=chat_template_jinja
+                    ),
+                    message_property_mappings={"role": "from", "content": "value"},
+                    field_messages="conversations",
+                ),
+                tokenizer=tokenizer,
+                train_on_inputs=False,
+                sequence_len=512,
+                roles_to_train=["assistant"],
+                train_on_eot=setting,
+                eot_tokens=[
+                    tokenizer.eos_token
+                ],  # Use eos_token as the EOT token for simplicity
+            )
+
+            res = strategy.tokenize_prompt(basic_dataset[0])
+            turns = strategy.get_conversation_thread(basic_dataset[0])
+            labels = res["labels"]
+            input_ids = res["input_ids"]
+
+            eos_token_id = tokenizer.eos_token_id
+            eos_indices = [
+                i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
+            ]
+
+            assert (
+                len(eos_indices) > 0
+            ), "Expected at least one EOS/EOT token in the input"
+
+            # Check labeling for each EOS/EOT token
+            for idx, eos_idx in enumerate(eos_indices):
+                # Find which turn this EOS token belongs to
+                preceding_turn = None
+                for i, turn in enumerate(basic_dataset[0]["conversations"]):
+                    start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+                    if (
+                        start_idx != -1
+                        and end_idx != -1
+                        and start_idx <= eos_idx <= end_idx + 1
+                    ):
+                        preceding_turn = turn
+                        break
+
+                is_assistant = (
+                    preceding_turn is not None and preceding_turn["from"] == "assistant"
+                )
+                is_last = idx == len(eos_indices) - 1
+
+                expected_label = not expected_train_func(
+                    idx, is_assistant if setting != "last" else is_last
+                )
+
+                if expected_label:
+                    assert (
+                        labels[eos_idx] == IGNORE_TOKEN_ID
+                    ), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
+                else:
+                    assert (
+                        labels[eos_idx] != IGNORE_TOKEN_ID
+                    ), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"