feat: add eos_tokens and train_on_eot for chat_template EOT parsing (#2364)
* feat: add eos_tokens and train_on_eot for chat_template EOT parsing * fix: comments * chore: add some examples of tokens * feat: add new potential errors for chat_template to faq * feat: add examples for EOT handling * fix: change error to warning for missing EOS * fix: warning typo * feat: add tests for eot token handling * fix: remove broken caplog capture in test * fix: chattemplate strategy with kd missing eot changes
This commit is contained in:
@@ -187,7 +187,7 @@ datasets:
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
|
||||
# Note: If the below 5 fields are empty, defaults to training only on the last message.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["assistant"] # default
|
||||
@@ -196,7 +196,13 @@ datasets:
|
||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
train_on_eos: last
|
||||
train_on_eos: turn
|
||||
# Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOT tokens
|
||||
# - turn: train on the EOT token at the end of each trainable turn
|
||||
# - last: train on the last EOT token in the conversation
|
||||
# If not specified, defaults to the value of train_on_eos for backward compatibility.
|
||||
train_on_eot:
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||
@@ -279,8 +285,17 @@ process_reward_model:
|
||||
chat_template: tokenizer_default
|
||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||
chat_template_jinja: null
|
||||
# Changes the default system message. Currently only supports chatml.
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
|
||||
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
|
||||
# These tokens mark the boundaries between conversation turns.
|
||||
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
|
||||
# If not specified, defaults to just the model's eos_token.
|
||||
# This is useful for templates that use multiple delimiter tokens.
|
||||
eot_tokens:
|
||||
# - "</s>"
|
||||
# - "[/INST]"
|
||||
# - "[/SYSTEM_PROMPT]"
|
||||
# Changes the default system message
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
@@ -665,8 +680,10 @@ special_tokens:
|
||||
# unk_token: "<unk>"
|
||||
# pad_token: "[PAD]"
|
||||
|
||||
# Add extra tokens.
|
||||
# Optional[list[str]]. Add extra tokens to the tokenizer.
|
||||
tokens:
|
||||
# - "<|startoftext|>"
|
||||
# - "<|endoftext|>"
|
||||
|
||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||
|
||||
@@ -4,18 +4,6 @@ description: Conversation format for supervised fine-tuning.
|
||||
order: 3
|
||||
---
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
|
||||
:::
|
||||
|
||||
## pygmalion
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
## chat_template
|
||||
|
||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||
@@ -64,7 +52,7 @@ We recommend checking the below examples for other usecases.
|
||||
|
||||
### Examples
|
||||
|
||||
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -109,10 +97,55 @@ datasets:
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||
:::
|
||||
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
- "[/INST]"
|
||||
# - "[/SYSTEM_PROMPT]"
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
# optional
|
||||
train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
|
||||
:::
|
||||
|
||||
::: {.callout-note}
|
||||
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
|
||||
|
||||
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
|
||||
:::
|
||||
|
||||
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
- "[/INST]"
|
||||
# ...
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
train_on_eos: last
|
||||
train_on_eot: turn
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
If EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, generally, you can leave them to their defaults and omit them.
|
||||
:::
|
||||
|
||||
|
||||
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
|
||||
@@ -162,3 +195,15 @@ datasets:
|
||||
::: {.callout-tip}
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section.
|
||||
:::
|
||||
|
||||
## pygmalion
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
34
docs/faq.qmd
34
docs/faq.qmd
@@ -73,10 +73,40 @@ description: Frequently asked questions
|
||||
|
||||
> A: This is likely an empty turn.
|
||||
|
||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||
**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**
|
||||
|
||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||
> A: There can be two reasons:
|
||||
|
||||
> 1. This is because of the mismatch between `tokenizer.eos_token` and EOS token in template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in template.
|
||||
|
||||
> 2. The EOS token is not in the template. Please check if your template is correct. As an example, `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.
|
||||
|
||||
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||
|
||||
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||
|
||||
**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
|
||||
|
||||
> A: There can be two reasons:
|
||||
|
||||
> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in template.
|
||||
|
||||
> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.
|
||||
|
||||
**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
|
||||
|
||||
> A: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples with the EOT token & tokenizer causing the issue.
|
||||
|
||||
**Q: `EOT token __ is encoded as multiple tokens.`**
|
||||
|
||||
> A: This is because the EOT token is encoded as multiple tokens which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.
|
||||
|
||||
**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
|
||||
|
||||
> A: This is because the EOS token is in the `eot_tokens: ` while mismatch between `train_on_eos: ` and `train_on_eot: `. This will cause one to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same or remove the EOS token from `eot_tokens: `.
|
||||
|
||||
**Q: If `eot_tokens: ` is not provided, what happens?**
|
||||
|
||||
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
|
||||
|
||||
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
|
||||
|
||||
@@ -35,6 +35,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos=None,
|
||||
train_on_eot=None,
|
||||
eot_tokens=None,
|
||||
logprobs_field="logprobs",
|
||||
gen_temperature=1.0,
|
||||
kd_temperature=1.0,
|
||||
@@ -50,6 +52,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
sequence_len,
|
||||
roles_to_train=roles_to_train,
|
||||
train_on_eos=train_on_eos,
|
||||
train_on_eot=train_on_eot,
|
||||
eot_tokens=eot_tokens,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -222,10 +222,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self,
|
||||
prompter: "ChatTemplatePrompter",
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos=None,
|
||||
train_on_inputs: bool,
|
||||
sequence_len: int,
|
||||
roles_to_train: Optional[List[str]] = None,
|
||||
train_on_eos: Optional[str] = None,
|
||||
train_on_eot: Optional[str] = None,
|
||||
eot_tokens: Optional[List[str]] = None,
|
||||
):
|
||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
@@ -238,12 +240,87 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
]
|
||||
|
||||
self.train_on_eos = train_on_eos
|
||||
# Backward compatibility, load from train_on_eos
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = (
|
||||
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||
)
|
||||
|
||||
self.images = "images"
|
||||
|
||||
LOG.debug(
|
||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||
)
|
||||
|
||||
self._validate_eot_and_eos_tokens()
|
||||
|
||||
def _validate_eot_and_eos_tokens(self):
|
||||
"""
|
||||
- Validates that EOT tokens (or eos_token) are in the chat_template
|
||||
- Checks if EOT tokens are encoded as multiple tokens in the tokenizer.
|
||||
- Checks for potential conflicts between train_on_eos and train_on_eot.
|
||||
"""
|
||||
if self.prompter.chat_template is None:
|
||||
# Usually this should not happen
|
||||
LOG.warning(
|
||||
"No chat template provided, skipping EOT and EOS token validation"
|
||||
)
|
||||
return
|
||||
|
||||
# If the EOT token is the same as the EOS token, we need to check differently
|
||||
if len(self.eot_tokens) == 1 and self.eot_tokens[0] == self.tokenizer.eos_token:
|
||||
# Check if the eos_token is in the chat_template or as a variable `eos_token`
|
||||
# Note: we check for `eos_token` in the string, but it could possibly not be a variable
|
||||
if (
|
||||
self.tokenizer.eos_token not in self.prompter.chat_template
|
||||
and "eos_token" not in self.prompter.chat_template
|
||||
):
|
||||
LOG.warning(
|
||||
f"EOS token '{self.tokenizer.eos_token}' not found in chat_template. Please check if your template/EOS token is correct."
|
||||
)
|
||||
return
|
||||
|
||||
# Create a new list to store tokens that should be kept
|
||||
valid_eot_tokens = []
|
||||
for token in self.eot_tokens:
|
||||
# Check if EOT token is in the chat_template
|
||||
if token not in self.prompter.chat_template:
|
||||
LOG.warning(f"EOT token '{token}' not found in chat_template.")
|
||||
# Don't add to the valid tokens list
|
||||
continue
|
||||
|
||||
valid_eot_tokens.append(token)
|
||||
|
||||
# Replace the original list with the filtered one
|
||||
self.eot_tokens = valid_eot_tokens
|
||||
|
||||
for token in self.eot_tokens:
|
||||
# If token in template, check if EOT token is in tokenizer and not encoded as multiple tokens
|
||||
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||
if not token_ids:
|
||||
raise ValueError(
|
||||
"EOT token encoding failed. Please check if the token is valid and can be encoded."
|
||||
)
|
||||
if token_ids and len(token_ids) > 1:
|
||||
raise ValueError(
|
||||
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config "
|
||||
"or (recommended) override unused added_tokens via `added_tokens_overrides: `."
|
||||
)
|
||||
|
||||
# If eos_token is in eot_tokens and conflict between train_on_eos and train_on_eot, raise an error
|
||||
if (
|
||||
self.tokenizer.eos_token in self.eot_tokens
|
||||
and self.train_on_eos != self.train_on_eot
|
||||
):
|
||||
raise ValueError(
|
||||
"Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot"
|
||||
f"train_on_eos: {self.train_on_eos}, train_on_eot: {self.train_on_eot}"
|
||||
f"eot_tokens: {self.eot_tokens}"
|
||||
f"eos_token: {self.tokenizer.eos_token}"
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
# Let calling code know we can handle lists of examples
|
||||
@@ -287,6 +364,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if (
|
||||
not self.roles_to_train
|
||||
and not self.train_on_eos
|
||||
and not self.train_on_eot
|
||||
and not self.prompter.message_field_training # type: ignore
|
||||
and not self.prompter.message_field_training_detail # type: ignore
|
||||
):
|
||||
@@ -322,6 +400,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
last_eos_idx = -1
|
||||
last_eot_idx = -1
|
||||
for index, turn in enumerate(turns):
|
||||
role = turn.get("role")
|
||||
content = turn.get("content")
|
||||
@@ -370,24 +449,45 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||
|
||||
# Handle EOS token
|
||||
eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
|
||||
if abs(eos_idx - turn_end_idx) <= 3: # Allow for some template padding
|
||||
last_eos_idx = eos_idx
|
||||
if self.train_on_eos == "all" or (
|
||||
self.train_on_eos == "turn" and should_train
|
||||
):
|
||||
labels[eos_idx] = input_ids[eos_idx]
|
||||
LOG.debug(f"EOS token set for training at index {eos_idx}")
|
||||
else:
|
||||
LOG.debug(
|
||||
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
|
||||
)
|
||||
# Handle special tokens (EOT and EOS)
|
||||
for token_type, find_func, train_option in [
|
||||
("EOT", self.find_first_eot_token, self.train_on_eot),
|
||||
("EOS", self.find_first_eos_token, self.train_on_eos),
|
||||
]:
|
||||
token_idx = find_func(input_ids, start_idx=turn_end_idx)
|
||||
|
||||
# Handle 'last' option for train_on_eos
|
||||
if self.train_on_eos == "last" and last_eos_idx != -1:
|
||||
labels[last_eos_idx] = input_ids[last_eos_idx]
|
||||
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
|
||||
if (
|
||||
token_idx != -1 and abs(token_idx - turn_end_idx) <= 3
|
||||
): # Allow for some template padding
|
||||
# Update the last token index
|
||||
if token_type == "EOT": # nosec B105
|
||||
last_eot_idx = token_idx
|
||||
else:
|
||||
last_eos_idx = token_idx
|
||||
|
||||
# Set labels if needed for this turn
|
||||
if train_option == "all" or (
|
||||
train_option == "turn" and should_train
|
||||
):
|
||||
labels[token_idx] = input_ids[token_idx]
|
||||
LOG.debug(
|
||||
f"{token_type} token set for training at index {token_idx}"
|
||||
)
|
||||
else:
|
||||
LOG.debug(
|
||||
f"{token_type} token missing after turn {turn}. {token_type.lower()}_idx: {token_idx}, turn_end_idx: {turn_end_idx}"
|
||||
)
|
||||
|
||||
# Handle 'last' option for special tokens
|
||||
for token_type, last_idx, train_option in [
|
||||
("EOT", last_eot_idx, self.train_on_eot),
|
||||
("EOS", last_eos_idx, self.train_on_eos),
|
||||
]:
|
||||
if train_option == "last" and last_idx != -1:
|
||||
labels[last_idx] = input_ids[last_idx]
|
||||
LOG.debug(
|
||||
f"Last {token_type} token set for training at index {last_idx}"
|
||||
)
|
||||
|
||||
LOG.debug(f"Final labels: {labels}")
|
||||
|
||||
@@ -404,6 +504,25 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_first_eot_token(self, input_ids, start_idx):
|
||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||
# Get token IDs for all EOT tokens
|
||||
eot_token_ids = []
|
||||
for token in self.eot_tokens:
|
||||
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||
if len(token_ids) != 1:
|
||||
raise ValueError(
|
||||
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
|
||||
)
|
||||
|
||||
eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple
|
||||
|
||||
# Search for any of the EOT token IDs
|
||||
for i in range(start_idx, len(input_ids)):
|
||||
if input_ids[i] in eot_token_ids:
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_turn(self, turns: list[dict], turn_idx: int):
|
||||
"""
|
||||
Locate the starting and ending indices of the specified turn in a conversation.
|
||||
@@ -568,6 +687,8 @@ class StrategyLoader:
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
"train_on_eot": ds_cfg.get("train_on_eot", None),
|
||||
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
|
||||
}
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -309,6 +309,7 @@ class AxolotlInputConfig(
|
||||
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
|
||||
) | None = None
|
||||
chat_template_jinja: str | None = None
|
||||
eot_tokens: list[str] | None = None
|
||||
default_system_message: str | None = None
|
||||
|
||||
fix_untrained_tokens: int | list[int] | None = None
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -53,14 +55,6 @@ class TestChatTemplateConfigurations:
|
||||
Test class for various configurations of ChatTemplateStrategy.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def find_sublist(full_list, sub_list):
|
||||
token_count = len(sub_list)
|
||||
for index in range(len(full_list) - token_count + 1):
|
||||
if full_list[index : index + token_count] == sub_list:
|
||||
return index
|
||||
return -1
|
||||
|
||||
@staticmethod
|
||||
def setup_tokenizer(
|
||||
tokenizer_name,
|
||||
@@ -68,6 +62,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template_jinja=None,
|
||||
eos_token=None,
|
||||
request=None,
|
||||
eot_token=None,
|
||||
) -> tuple[PreTrainedTokenizer, str]:
|
||||
"""
|
||||
Helper function to set up the tokenizer and chat template for the test.
|
||||
@@ -88,6 +83,10 @@ class TestChatTemplateConfigurations:
|
||||
"CodeLlamaTokenizerFast",
|
||||
):
|
||||
tokenizer.update_post_processor()
|
||||
|
||||
if eot_token:
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
|
||||
|
||||
return tokenizer, chat_template_jinja
|
||||
|
||||
def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):
|
||||
@@ -974,3 +973,311 @@ class TestChatTemplateConfigurations:
|
||||
raise ValueError(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
|
||||
def test_eot_tokens_conflict_with_eos_token(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset, # pylint: disable=unused-argument
|
||||
request,
|
||||
):
|
||||
"""Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
|
||||
LOG.info(
|
||||
"Testing conflict between eot_tokens containing eos_token and train_on_eot/train_on_eos mismatch"
|
||||
)
|
||||
|
||||
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
)
|
||||
|
||||
# Create a situation where eot_tokens contains eos_token
|
||||
eot_tokens = [
|
||||
tokenizer.eos_token,
|
||||
"[/INST]",
|
||||
] # Deliberately including eos_token
|
||||
|
||||
# Create conflicting train_on_eos and train_on_eot settings
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=".*eos_token is in eot_tokens and train_on_eos != train_on_eot.*",
|
||||
):
|
||||
ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="none", # Setting to none
|
||||
train_on_eot="turn", # Different from train_on_eos
|
||||
eot_tokens=eot_tokens,
|
||||
)
|
||||
|
||||
def test_eot_token_backward_compatibility(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset, # pylint: disable=unused-argument
|
||||
request,
|
||||
):
|
||||
"""Test that eot_tokens inherits from eos_token when not specified"""
|
||||
LOG.info("Testing backward compatibility that eot_token inherits eos_token")
|
||||
|
||||
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
)
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="turn", # Setting train_on_eos to "turn"
|
||||
)
|
||||
|
||||
# In backward compatibility mode, eot_tokens should be derived from eos_token
|
||||
assert strategy.eot_tokens == [
|
||||
tokenizer.eos_token
|
||||
], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
|
||||
assert (
|
||||
strategy.train_on_eot == "turn"
|
||||
), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
|
||||
|
||||
def test_token_not_in_template(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
"""Test runs even when tokens are not found in the template"""
|
||||
LOG.info("Testing runs even when tokens are not found in template")
|
||||
|
||||
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
)
|
||||
|
||||
# Create a non-existent token that definitely won't be in the template
|
||||
non_existent_token = "[DEFINITELY_NOT_IN_TEMPLATE]"
|
||||
tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": [non_existent_token]}
|
||||
)
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
eot_tokens=[non_existent_token],
|
||||
)
|
||||
|
||||
# Force template check by calling tokenize_prompt
|
||||
strategy.tokenize_prompt(basic_dataset[0])
|
||||
|
||||
# We can also check that a warning was logged, but there's
|
||||
# caplog conflicts when running with other tests
|
||||
# assert any(
|
||||
# "not found in chat_template" in record.message for record in self._caplog.records
|
||||
# ), "Expected warning about token not found in template was not logged"
|
||||
|
||||
def test_custom_eot_tokens(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token, # pylint: disable=unused-argument
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
"""Test with custom EOT tokens to ensure proper masking and training"""
|
||||
LOG.info("Testing with custom EOT tokens")
|
||||
|
||||
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||
tokenizer, chat_template, chat_template_jinja, None, request
|
||||
)
|
||||
|
||||
# Add custom EOT tokens to the tokenizer
|
||||
custom_eot = "[EOT]"
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [custom_eot]})
|
||||
|
||||
# Create a custom chat template that uses our EOT token
|
||||
custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content'] }}{% elif message['role'] == 'user' %}User: {{ message['content'] }}{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}[EOT]{% endif %}{% endfor %}"""
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=custom_template,
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eot="turn", # Train on EOT token after each turn
|
||||
eot_tokens=[custom_eot],
|
||||
)
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
# Find indices of the EOT token
|
||||
eot_token_id = tokenizer.convert_tokens_to_ids(custom_eot)
|
||||
eot_indices = [
|
||||
i for i, token_id in enumerate(input_ids) if token_id == eot_token_id
|
||||
]
|
||||
|
||||
assert len(eot_indices) > 0, "Expected at least one EOT token in the input"
|
||||
|
||||
# Verify labeling for EOT tokens based on role
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
assistant_turn_indices = []
|
||||
non_assistant_turn_indices = []
|
||||
|
||||
for i, turn in enumerate(basic_dataset[0]["conversations"]):
|
||||
start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
|
||||
if start_idx != -1 and end_idx != -1: # If turn is found
|
||||
if turn["from"] == "assistant":
|
||||
assistant_turn_indices.append((start_idx, end_idx))
|
||||
else:
|
||||
non_assistant_turn_indices.append((start_idx, end_idx))
|
||||
|
||||
# Check EOT tokens after assistant turns are labeled
|
||||
for eot_idx in eot_indices:
|
||||
is_after_assistant = any(
|
||||
start_idx <= eot_idx <= end_idx + 1 # +1 to include the EOT token
|
||||
for start_idx, end_idx in assistant_turn_indices
|
||||
)
|
||||
|
||||
if is_after_assistant:
|
||||
assert (
|
||||
labels[eot_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
|
||||
else:
|
||||
assert (
|
||||
labels[eot_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
|
||||
|
||||
def test_multiple_train_on_eot_settings(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
"""Test different train_on_eot settings"""
|
||||
LOG.info("Testing different train_on_eot settings")
|
||||
|
||||
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
)
|
||||
|
||||
# Create a list to test different train_on_eot settings
|
||||
test_settings = [
|
||||
("none", lambda idx, is_assistant: False), # Never train on EOT
|
||||
("all", lambda idx, is_assistant: True), # Always train on EOT
|
||||
(
|
||||
"turn",
|
||||
lambda idx, is_assistant: is_assistant,
|
||||
), # Train on EOT after assistant turns
|
||||
("last", lambda idx, is_last: is_last), # Only train on last EOT
|
||||
]
|
||||
|
||||
for setting, expected_train_func in test_settings:
|
||||
LOG.info(f"Testing train_on_eot='{setting}'")
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eot=setting,
|
||||
eot_tokens=[
|
||||
tokenizer.eos_token
|
||||
], # Use eos_token as the EOT token for simplicity
|
||||
)
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
eos_indices = [
|
||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||
]
|
||||
|
||||
assert (
|
||||
len(eos_indices) > 0
|
||||
), "Expected at least one EOS/EOT token in the input"
|
||||
|
||||
# Check labeling for each EOS/EOT token
|
||||
for idx, eos_idx in enumerate(eos_indices):
|
||||
# Find which turn this EOS token belongs to
|
||||
preceding_turn = None
|
||||
for i, turn in enumerate(basic_dataset[0]["conversations"]):
|
||||
start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
|
||||
if (
|
||||
start_idx != -1
|
||||
and end_idx != -1
|
||||
and start_idx <= eos_idx <= end_idx + 1
|
||||
):
|
||||
preceding_turn = turn
|
||||
break
|
||||
|
||||
is_assistant = (
|
||||
preceding_turn is not None and preceding_turn["from"] == "assistant"
|
||||
)
|
||||
is_last = idx == len(eos_indices) - 1
|
||||
|
||||
expected_label = not expected_train_func(
|
||||
idx, is_assistant if setting != "last" else is_last
|
||||
)
|
||||
|
||||
if expected_label:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
|
||||
else:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
|
||||
Reference in New Issue
Block a user