feat: add eos_tokens and train_on_eot for chat_template EOT parsing (#2364)
* feat: add eos_tokens and train_on_eot for chat_template EOT parsing * fix: comments * chore: add some examples of tokens * feat: add new potential errors for chat_template to faq * feat: add examples for EOT handling * fix: change error to warning for missing EOS * fix: warning typo * feat: add tests for eot token handling * fix: remove broken caplog capture in test * fix: chattemplate strategy with kd missing eot changes
This commit is contained in:
@@ -187,7 +187,7 @@ datasets:
|
|||||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||||
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
|
# Note: If the below 5 fields are empty, defaults to training only on the last message.
|
||||||
|
|
||||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||||
roles_to_train: ["assistant"] # default
|
roles_to_train: ["assistant"] # default
|
||||||
@@ -196,7 +196,13 @@ datasets:
|
|||||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||||
# - last: train on the last EOS token in the conversation
|
# - last: train on the last EOS token in the conversation
|
||||||
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||||
train_on_eos: last
|
train_on_eos: turn
|
||||||
|
# Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
|
||||||
|
# - all: train on all EOT tokens
|
||||||
|
# - turn: train on the EOT token at the end of each trainable turn
|
||||||
|
# - last: train on the last EOT token in the conversation
|
||||||
|
# If not specified, defaults to the value of train_on_eos for backward compatibility.
|
||||||
|
train_on_eot:
|
||||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||||
message_field_training: training
|
message_field_training: training
|
||||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||||
@@ -279,8 +285,17 @@ process_reward_model:
|
|||||||
chat_template: tokenizer_default
|
chat_template: tokenizer_default
|
||||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||||
chat_template_jinja: null
|
chat_template_jinja: null
|
||||||
# Changes the default system message. Currently only supports chatml.
|
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
|
||||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
|
# These tokens mark the boundaries between conversation turns.
|
||||||
|
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
|
||||||
|
# If not specified, defaults to just the model's eos_token.
|
||||||
|
# This is useful for templates that use multiple delimiter tokens.
|
||||||
|
eot_tokens:
|
||||||
|
# - "</s>"
|
||||||
|
# - "[/INST]"
|
||||||
|
# - "[/SYSTEM_PROMPT]"
|
||||||
|
# Changes the default system message
|
||||||
|
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
@@ -665,8 +680,10 @@ special_tokens:
|
|||||||
# unk_token: "<unk>"
|
# unk_token: "<unk>"
|
||||||
# pad_token: "[PAD]"
|
# pad_token: "[PAD]"
|
||||||
|
|
||||||
# Add extra tokens.
|
# Optional[list[str]]. Add extra tokens to the tokenizer.
|
||||||
tokens:
|
tokens:
|
||||||
|
# - "<|startoftext|>"
|
||||||
|
# - "<|endoftext|>"
|
||||||
|
|
||||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||||
|
|||||||
@@ -4,18 +4,6 @@ description: Conversation format for supervised fine-tuning.
|
|||||||
order: 3
|
order: 3
|
||||||
---
|
---
|
||||||
|
|
||||||
## sharegpt
|
|
||||||
|
|
||||||
::: {.callout-important}
|
|
||||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
|
|
||||||
:::
|
|
||||||
|
|
||||||
## pygmalion
|
|
||||||
|
|
||||||
```{.json filename="data.jsonl"}
|
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
|
||||||
```
|
|
||||||
|
|
||||||
## chat_template
|
## chat_template
|
||||||
|
|
||||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||||
@@ -64,7 +52,7 @@ We recommend checking the below examples for other usecases.
|
|||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
@@ -109,10 +97,55 @@ datasets:
|
|||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
eot_tokens:
|
||||||
|
- "[/INST]"
|
||||||
|
# - "[/SYSTEM_PROMPT]"
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
# optional
|
||||||
|
train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
|
||||||
|
|
||||||
|
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
|
||||||
|
:::
|
||||||
|
|
||||||
|
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
eot_tokens:
|
||||||
|
- "[/INST]"
|
||||||
|
# ...
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
train_on_eos: last
|
||||||
|
train_on_eot: turn
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
If EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, generally, you can leave them to their defaults and omit them.
|
||||||
|
:::
|
||||||
|
|
||||||
|
|
||||||
|
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||||
|
|
||||||
For a data sample that looks like:
|
For a data sample that looks like:
|
||||||
|
|
||||||
@@ -162,3 +195,15 @@ datasets:
|
|||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
## sharegpt
|
||||||
|
|
||||||
|
::: {.callout-important}
|
||||||
|
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section.
|
||||||
|
:::
|
||||||
|
|
||||||
|
## pygmalion
|
||||||
|
|
||||||
|
```{.json filename="data.jsonl"}
|
||||||
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
|
```
|
||||||
|
|||||||
34
docs/faq.qmd
34
docs/faq.qmd
@@ -73,10 +73,40 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: This is likely an empty turn.
|
> A: This is likely an empty turn.
|
||||||
|
|
||||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**
|
||||||
|
|
||||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
> A: There can be two reasons:
|
||||||
|
|
||||||
|
> 1. This is because of the mismatch between `tokenizer.eos_token` and EOS token in template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in template.
|
||||||
|
|
||||||
|
> 2. The EOS token is not in the template. Please check if your template is correct. As an example, `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.
|
||||||
|
|
||||||
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||||
|
|
||||||
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||||
|
|
||||||
|
**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
|
||||||
|
|
||||||
|
> A: There can be two reasons:
|
||||||
|
|
||||||
|
> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in template.
|
||||||
|
|
||||||
|
> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.
|
||||||
|
|
||||||
|
**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
|
||||||
|
|
||||||
|
> A: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples with the EOT token & tokenizer causing the issue.
|
||||||
|
|
||||||
|
**Q: `EOT token __ is encoded as multiple tokens.`**
|
||||||
|
|
||||||
|
> A: This is because the EOT token is encoded as multiple tokens which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.
|
||||||
|
|
||||||
|
**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
|
||||||
|
|
||||||
|
> A: This is because the EOS token is in the `eot_tokens: ` while mismatch between `train_on_eos: ` and `train_on_eot: `. This will cause one to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same or remove the EOS token from `eot_tokens: `.
|
||||||
|
|
||||||
|
**Q: If `eot_tokens: ` is not provided, what happens?**
|
||||||
|
|
||||||
|
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
|
||||||
|
|
||||||
|
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
sequence_len,
|
sequence_len,
|
||||||
roles_to_train=None,
|
roles_to_train=None,
|
||||||
train_on_eos=None,
|
train_on_eos=None,
|
||||||
|
train_on_eot=None,
|
||||||
|
eot_tokens=None,
|
||||||
logprobs_field="logprobs",
|
logprobs_field="logprobs",
|
||||||
gen_temperature=1.0,
|
gen_temperature=1.0,
|
||||||
kd_temperature=1.0,
|
kd_temperature=1.0,
|
||||||
@@ -50,6 +52,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
sequence_len,
|
sequence_len,
|
||||||
roles_to_train=roles_to_train,
|
roles_to_train=roles_to_train,
|
||||||
train_on_eos=train_on_eos,
|
train_on_eos=train_on_eos,
|
||||||
|
train_on_eot=train_on_eot,
|
||||||
|
eot_tokens=eot_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -222,10 +222,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
self,
|
self,
|
||||||
prompter: "ChatTemplatePrompter",
|
prompter: "ChatTemplatePrompter",
|
||||||
tokenizer,
|
tokenizer,
|
||||||
train_on_inputs,
|
train_on_inputs: bool,
|
||||||
sequence_len,
|
sequence_len: int,
|
||||||
roles_to_train=None,
|
roles_to_train: Optional[List[str]] = None,
|
||||||
train_on_eos=None,
|
train_on_eos: Optional[str] = None,
|
||||||
|
train_on_eot: Optional[str] = None,
|
||||||
|
eot_tokens: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||||
self.prompter: ChatTemplatePrompter = prompter
|
self.prompter: ChatTemplatePrompter = prompter
|
||||||
@@ -238,12 +240,87 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.train_on_eos = train_on_eos
|
self.train_on_eos = train_on_eos
|
||||||
|
# Backward compatibility, load from train_on_eos
|
||||||
|
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||||
|
|
||||||
|
# Default to eos_token if eot_tokens not provided
|
||||||
|
self.eot_tokens = (
|
||||||
|
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||||
|
)
|
||||||
|
|
||||||
self.images = "images"
|
self.images = "images"
|
||||||
|
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._validate_eot_and_eos_tokens()
|
||||||
|
|
||||||
|
def _validate_eot_and_eos_tokens(self):
|
||||||
|
"""
|
||||||
|
- Validates that EOT tokens (or eos_token) are in the chat_template
|
||||||
|
- Checks if EOT tokens are encoded as multiple tokens in the tokenizer.
|
||||||
|
- Checks for potential conflicts between train_on_eos and train_on_eot.
|
||||||
|
"""
|
||||||
|
if self.prompter.chat_template is None:
|
||||||
|
# Usually this should not happen
|
||||||
|
LOG.warning(
|
||||||
|
"No chat template provided, skipping EOT and EOS token validation"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# If the EOT token is the same as the EOS token, we need to check differently
|
||||||
|
if len(self.eot_tokens) == 1 and self.eot_tokens[0] == self.tokenizer.eos_token:
|
||||||
|
# Check if the eos_token is in the chat_template or as a variable `eos_token`
|
||||||
|
# Note: we check for `eos_token` in the string, but it could possibly not be a variable
|
||||||
|
if (
|
||||||
|
self.tokenizer.eos_token not in self.prompter.chat_template
|
||||||
|
and "eos_token" not in self.prompter.chat_template
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
f"EOS token '{self.tokenizer.eos_token}' not found in chat_template. Please check if your template/EOS token is correct."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create a new list to store tokens that should be kept
|
||||||
|
valid_eot_tokens = []
|
||||||
|
for token in self.eot_tokens:
|
||||||
|
# Check if EOT token is in the chat_template
|
||||||
|
if token not in self.prompter.chat_template:
|
||||||
|
LOG.warning(f"EOT token '{token}' not found in chat_template.")
|
||||||
|
# Don't add to the valid tokens list
|
||||||
|
continue
|
||||||
|
|
||||||
|
valid_eot_tokens.append(token)
|
||||||
|
|
||||||
|
# Replace the original list with the filtered one
|
||||||
|
self.eot_tokens = valid_eot_tokens
|
||||||
|
|
||||||
|
for token in self.eot_tokens:
|
||||||
|
# If token in template, check if EOT token is in tokenizer and not encoded as multiple tokens
|
||||||
|
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||||
|
if not token_ids:
|
||||||
|
raise ValueError(
|
||||||
|
"EOT token encoding failed. Please check if the token is valid and can be encoded."
|
||||||
|
)
|
||||||
|
if token_ids and len(token_ids) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config "
|
||||||
|
"or (recommended) override unused added_tokens via `added_tokens_overrides: `."
|
||||||
|
)
|
||||||
|
|
||||||
|
# If eos_token is in eot_tokens and conflict between train_on_eos and train_on_eot, raise an error
|
||||||
|
if (
|
||||||
|
self.tokenizer.eos_token in self.eot_tokens
|
||||||
|
and self.train_on_eos != self.train_on_eot
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot"
|
||||||
|
f"train_on_eos: {self.train_on_eos}, train_on_eot: {self.train_on_eot}"
|
||||||
|
f"eot_tokens: {self.eot_tokens}"
|
||||||
|
f"eos_token: {self.tokenizer.eos_token}"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_batched(self) -> bool:
|
def supports_batched(self) -> bool:
|
||||||
# Let calling code know we can handle lists of examples
|
# Let calling code know we can handle lists of examples
|
||||||
@@ -287,6 +364,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if (
|
if (
|
||||||
not self.roles_to_train
|
not self.roles_to_train
|
||||||
and not self.train_on_eos
|
and not self.train_on_eos
|
||||||
|
and not self.train_on_eot
|
||||||
and not self.prompter.message_field_training # type: ignore
|
and not self.prompter.message_field_training # type: ignore
|
||||||
and not self.prompter.message_field_training_detail # type: ignore
|
and not self.prompter.message_field_training_detail # type: ignore
|
||||||
):
|
):
|
||||||
@@ -322,6 +400,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||||
|
|
||||||
last_eos_idx = -1
|
last_eos_idx = -1
|
||||||
|
last_eot_idx = -1
|
||||||
for index, turn in enumerate(turns):
|
for index, turn in enumerate(turns):
|
||||||
role = turn.get("role")
|
role = turn.get("role")
|
||||||
content = turn.get("content")
|
content = turn.get("content")
|
||||||
@@ -370,24 +449,45 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||||
|
|
||||||
# Handle EOS token
|
# Handle special tokens (EOT and EOS)
|
||||||
eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
|
for token_type, find_func, train_option in [
|
||||||
if abs(eos_idx - turn_end_idx) <= 3: # Allow for some template padding
|
("EOT", self.find_first_eot_token, self.train_on_eot),
|
||||||
last_eos_idx = eos_idx
|
("EOS", self.find_first_eos_token, self.train_on_eos),
|
||||||
if self.train_on_eos == "all" or (
|
]:
|
||||||
self.train_on_eos == "turn" and should_train
|
token_idx = find_func(input_ids, start_idx=turn_end_idx)
|
||||||
):
|
|
||||||
labels[eos_idx] = input_ids[eos_idx]
|
|
||||||
LOG.debug(f"EOS token set for training at index {eos_idx}")
|
|
||||||
else:
|
|
||||||
LOG.debug(
|
|
||||||
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle 'last' option for train_on_eos
|
if (
|
||||||
if self.train_on_eos == "last" and last_eos_idx != -1:
|
token_idx != -1 and abs(token_idx - turn_end_idx) <= 3
|
||||||
labels[last_eos_idx] = input_ids[last_eos_idx]
|
): # Allow for some template padding
|
||||||
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
|
# Update the last token index
|
||||||
|
if token_type == "EOT": # nosec B105
|
||||||
|
last_eot_idx = token_idx
|
||||||
|
else:
|
||||||
|
last_eos_idx = token_idx
|
||||||
|
|
||||||
|
# Set labels if needed for this turn
|
||||||
|
if train_option == "all" or (
|
||||||
|
train_option == "turn" and should_train
|
||||||
|
):
|
||||||
|
labels[token_idx] = input_ids[token_idx]
|
||||||
|
LOG.debug(
|
||||||
|
f"{token_type} token set for training at index {token_idx}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.debug(
|
||||||
|
f"{token_type} token missing after turn {turn}. {token_type.lower()}_idx: {token_idx}, turn_end_idx: {turn_end_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle 'last' option for special tokens
|
||||||
|
for token_type, last_idx, train_option in [
|
||||||
|
("EOT", last_eot_idx, self.train_on_eot),
|
||||||
|
("EOS", last_eos_idx, self.train_on_eos),
|
||||||
|
]:
|
||||||
|
if train_option == "last" and last_idx != -1:
|
||||||
|
labels[last_idx] = input_ids[last_idx]
|
||||||
|
LOG.debug(
|
||||||
|
f"Last {token_type} token set for training at index {last_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
LOG.debug(f"Final labels: {labels}")
|
LOG.debug(f"Final labels: {labels}")
|
||||||
|
|
||||||
@@ -404,6 +504,25 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return i
|
return i
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
|
def find_first_eot_token(self, input_ids, start_idx):
|
||||||
|
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||||
|
# Get token IDs for all EOT tokens
|
||||||
|
eot_token_ids = []
|
||||||
|
for token in self.eot_tokens:
|
||||||
|
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||||
|
if len(token_ids) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
|
||||||
|
)
|
||||||
|
|
||||||
|
eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple
|
||||||
|
|
||||||
|
# Search for any of the EOT token IDs
|
||||||
|
for i in range(start_idx, len(input_ids)):
|
||||||
|
if input_ids[i] in eot_token_ids:
|
||||||
|
return i
|
||||||
|
return -1
|
||||||
|
|
||||||
def find_turn(self, turns: list[dict], turn_idx: int):
|
def find_turn(self, turns: list[dict], turn_idx: int):
|
||||||
"""
|
"""
|
||||||
Locate the starting and ending indices of the specified turn in a conversation.
|
Locate the starting and ending indices of the specified turn in a conversation.
|
||||||
@@ -568,6 +687,8 @@ class StrategyLoader:
|
|||||||
"sequence_len": cfg.sequence_len,
|
"sequence_len": cfg.sequence_len,
|
||||||
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||||
|
"train_on_eot": ds_cfg.get("train_on_eot", None),
|
||||||
|
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -309,6 +309,7 @@ class AxolotlInputConfig(
|
|||||||
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
|
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
|
||||||
) | None = None
|
) | None = None
|
||||||
chat_template_jinja: str | None = None
|
chat_template_jinja: str | None = None
|
||||||
|
eot_tokens: list[str] | None = None
|
||||||
default_system_message: str | None = None
|
default_system_message: str | None = None
|
||||||
|
|
||||||
fix_untrained_tokens: int | list[int] | None = None
|
fix_untrained_tokens: int | list[int] | None = None
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
tests for chat_template prompt strategy
|
tests for chat_template prompt strategy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
@@ -53,14 +55,6 @@ class TestChatTemplateConfigurations:
|
|||||||
Test class for various configurations of ChatTemplateStrategy.
|
Test class for various configurations of ChatTemplateStrategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def find_sublist(full_list, sub_list):
|
|
||||||
token_count = len(sub_list)
|
|
||||||
for index in range(len(full_list) - token_count + 1):
|
|
||||||
if full_list[index : index + token_count] == sub_list:
|
|
||||||
return index
|
|
||||||
return -1
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def setup_tokenizer(
|
def setup_tokenizer(
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
@@ -68,6 +62,7 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template_jinja=None,
|
chat_template_jinja=None,
|
||||||
eos_token=None,
|
eos_token=None,
|
||||||
request=None,
|
request=None,
|
||||||
|
eot_token=None,
|
||||||
) -> tuple[PreTrainedTokenizer, str]:
|
) -> tuple[PreTrainedTokenizer, str]:
|
||||||
"""
|
"""
|
||||||
Helper function to set up the tokenizer and chat template for the test.
|
Helper function to set up the tokenizer and chat template for the test.
|
||||||
@@ -88,6 +83,10 @@ class TestChatTemplateConfigurations:
|
|||||||
"CodeLlamaTokenizerFast",
|
"CodeLlamaTokenizerFast",
|
||||||
):
|
):
|
||||||
tokenizer.update_post_processor()
|
tokenizer.update_post_processor()
|
||||||
|
|
||||||
|
if eot_token:
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
|
||||||
|
|
||||||
return tokenizer, chat_template_jinja
|
return tokenizer, chat_template_jinja
|
||||||
|
|
||||||
def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):
|
def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):
|
||||||
@@ -974,3 +973,311 @@ class TestChatTemplateConfigurations:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_eot_tokens_conflict_with_eos_token(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token,
|
||||||
|
basic_dataset, # pylint: disable=unused-argument
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
"""Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
|
||||||
|
LOG.info(
|
||||||
|
"Testing conflict between eot_tokens containing eos_token and train_on_eot/train_on_eos mismatch"
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a situation where eot_tokens contains eos_token
|
||||||
|
eot_tokens = [
|
||||||
|
tokenizer.eos_token,
|
||||||
|
"[/INST]",
|
||||||
|
] # Deliberately including eos_token
|
||||||
|
|
||||||
|
# Create conflicting train_on_eos and train_on_eot settings
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=".*eos_token is in eot_tokens and train_on_eos != train_on_eot.*",
|
||||||
|
):
|
||||||
|
ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=chat_template_jinja
|
||||||
|
),
|
||||||
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
|
field_messages="conversations",
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="none", # Setting to none
|
||||||
|
train_on_eot="turn", # Different from train_on_eos
|
||||||
|
eot_tokens=eot_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_eot_token_backward_compatibility(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token,
|
||||||
|
basic_dataset, # pylint: disable=unused-argument
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
"""Test that eot_tokens inherits from eos_token when not specified"""
|
||||||
|
LOG.info("Testing backward compatibility that eot_token inherits eos_token")
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=chat_template_jinja
|
||||||
|
),
|
||||||
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
|
field_messages="conversations",
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="turn", # Setting train_on_eos to "turn"
|
||||||
|
)
|
||||||
|
|
||||||
|
# In backward compatibility mode, eot_tokens should be derived from eos_token
|
||||||
|
assert strategy.eot_tokens == [
|
||||||
|
tokenizer.eos_token
|
||||||
|
], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
|
||||||
|
assert (
|
||||||
|
strategy.train_on_eot == "turn"
|
||||||
|
), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
|
||||||
|
|
||||||
|
def test_token_not_in_template(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token,
|
||||||
|
basic_dataset,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
"""Test runs even when tokens are not found in the template"""
|
||||||
|
LOG.info("Testing runs even when tokens are not found in template")
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a non-existent token that definitely won't be in the template
|
||||||
|
non_existent_token = "[DEFINITELY_NOT_IN_TEMPLATE]"
|
||||||
|
tokenizer.add_special_tokens(
|
||||||
|
{"additional_special_tokens": [non_existent_token]}
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=chat_template_jinja
|
||||||
|
),
|
||||||
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
|
field_messages="conversations",
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
eot_tokens=[non_existent_token],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force template check by calling tokenize_prompt
|
||||||
|
strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
|
||||||
|
# We can also check that a warning was logged, but there's
|
||||||
|
# caplog conflicts when running with other tests
|
||||||
|
# assert any(
|
||||||
|
# "not found in chat_template" in record.message for record in self._caplog.records
|
||||||
|
# ), "Expected warning about token not found in template was not logged"
|
||||||
|
|
||||||
|
def test_custom_eot_tokens(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token, # pylint: disable=unused-argument
|
||||||
|
basic_dataset,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
"""Test with custom EOT tokens to ensure proper masking and training"""
|
||||||
|
LOG.info("Testing with custom EOT tokens")
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, None, request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add custom EOT tokens to the tokenizer
|
||||||
|
custom_eot = "[EOT]"
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [custom_eot]})
|
||||||
|
|
||||||
|
# Create a custom chat template that uses our EOT token
|
||||||
|
custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content'] }}{% elif message['role'] == 'user' %}User: {{ message['content'] }}{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}[EOT]{% endif %}{% endfor %}"""
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=custom_template,
|
||||||
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
|
field_messages="conversations",
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eot="turn", # Train on EOT token after each turn
|
||||||
|
eot_tokens=[custom_eot],
|
||||||
|
)
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Find indices of the EOT token
|
||||||
|
eot_token_id = tokenizer.convert_tokens_to_ids(custom_eot)
|
||||||
|
eot_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eot_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eot_indices) > 0, "Expected at least one EOT token in the input"
|
||||||
|
|
||||||
|
# Verify labeling for EOT tokens based on role
|
||||||
|
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||||
|
assistant_turn_indices = []
|
||||||
|
non_assistant_turn_indices = []
|
||||||
|
|
||||||
|
for i, turn in enumerate(basic_dataset[0]["conversations"]):
|
||||||
|
start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
|
||||||
|
if start_idx != -1 and end_idx != -1: # If turn is found
|
||||||
|
if turn["from"] == "assistant":
|
||||||
|
assistant_turn_indices.append((start_idx, end_idx))
|
||||||
|
else:
|
||||||
|
non_assistant_turn_indices.append((start_idx, end_idx))
|
||||||
|
|
||||||
|
# Check EOT tokens after assistant turns are labeled
|
||||||
|
for eot_idx in eot_indices:
|
||||||
|
is_after_assistant = any(
|
||||||
|
start_idx <= eot_idx <= end_idx + 1 # +1 to include the EOT token
|
||||||
|
for start_idx, end_idx in assistant_turn_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_after_assistant:
|
||||||
|
assert (
|
||||||
|
labels[eot_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
labels[eot_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
|
||||||
|
|
||||||
|
def test_multiple_train_on_eot_settings(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template,
|
||||||
|
chat_template_jinja,
|
||||||
|
eos_token,
|
||||||
|
basic_dataset,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
"""Test different train_on_eot settings"""
|
||||||
|
LOG.info("Testing different train_on_eot settings")
|
||||||
|
|
||||||
|
tokenizer, chat_template_jinja = self.setup_tokenizer(
|
||||||
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a list to test different train_on_eot settings
|
||||||
|
test_settings = [
|
||||||
|
("none", lambda idx, is_assistant: False), # Never train on EOT
|
||||||
|
("all", lambda idx, is_assistant: True), # Always train on EOT
|
||||||
|
(
|
||||||
|
"turn",
|
||||||
|
lambda idx, is_assistant: is_assistant,
|
||||||
|
), # Train on EOT after assistant turns
|
||||||
|
("last", lambda idx, is_last: is_last), # Only train on last EOT
|
||||||
|
]
|
||||||
|
|
||||||
|
for setting, expected_train_func in test_settings:
|
||||||
|
LOG.info(f"Testing train_on_eot='{setting}'")
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=chat_template_jinja
|
||||||
|
),
|
||||||
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
|
field_messages="conversations",
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eot=setting,
|
||||||
|
eot_tokens=[
|
||||||
|
tokenizer.eos_token
|
||||||
|
], # Use eos_token as the EOT token for simplicity
|
||||||
|
)
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(eos_indices) > 0
|
||||||
|
), "Expected at least one EOS/EOT token in the input"
|
||||||
|
|
||||||
|
# Check labeling for each EOS/EOT token
|
||||||
|
for idx, eos_idx in enumerate(eos_indices):
|
||||||
|
# Find which turn this EOS token belongs to
|
||||||
|
preceding_turn = None
|
||||||
|
for i, turn in enumerate(basic_dataset[0]["conversations"]):
|
||||||
|
start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
|
||||||
|
if (
|
||||||
|
start_idx != -1
|
||||||
|
and end_idx != -1
|
||||||
|
and start_idx <= eos_idx <= end_idx + 1
|
||||||
|
):
|
||||||
|
preceding_turn = turn
|
||||||
|
break
|
||||||
|
|
||||||
|
is_assistant = (
|
||||||
|
preceding_turn is not None and preceding_turn["from"] == "assistant"
|
||||||
|
)
|
||||||
|
is_last = idx == len(eos_indices) - 1
|
||||||
|
|
||||||
|
expected_label = not expected_train_func(
|
||||||
|
idx, is_assistant if setting != "last" else is_last
|
||||||
|
)
|
||||||
|
|
||||||
|
if expected_label:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||||
|
|||||||
Reference in New Issue
Block a user