feat: add eos_tokens and train_on_eot for chat_template EOT parsing (#2364)
* feat: add eos_tokens and train_on_eot for chat_template EOT parsing * fix: comments * chore: add some examples of tokens * feat: add new potential errors for chat_template to faq * feat: add examples for EOT handling * fix: change error to warning for missing EOS * fix: warning typo * feat: add tests for eot token handling * fix: remove broken caplog capture in test * fix: chattemplate strategy with kd missing eot changes
This commit is contained in:
@@ -35,6 +35,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos=None,
|
||||
train_on_eot=None,
|
||||
eot_tokens=None,
|
||||
logprobs_field="logprobs",
|
||||
gen_temperature=1.0,
|
||||
kd_temperature=1.0,
|
||||
@@ -50,6 +52,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
sequence_len,
|
||||
roles_to_train=roles_to_train,
|
||||
train_on_eos=train_on_eos,
|
||||
train_on_eot=train_on_eot,
|
||||
eot_tokens=eot_tokens,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -222,10 +222,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self,
|
||||
prompter: "ChatTemplatePrompter",
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos=None,
|
||||
train_on_inputs: bool,
|
||||
sequence_len: int,
|
||||
roles_to_train: Optional[List[str]] = None,
|
||||
train_on_eos: Optional[str] = None,
|
||||
train_on_eot: Optional[str] = None,
|
||||
eot_tokens: Optional[List[str]] = None,
|
||||
):
|
||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
@@ -238,12 +240,87 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
]
|
||||
|
||||
self.train_on_eos = train_on_eos
|
||||
# Backward compatibility, load from train_on_eos
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = (
|
||||
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||
)
|
||||
|
||||
self.images = "images"
|
||||
|
||||
LOG.debug(
|
||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||
)
|
||||
|
||||
self._validate_eot_and_eos_tokens()
|
||||
|
||||
def _validate_eot_and_eos_tokens(self):
|
||||
"""
|
||||
- Validates that EOT tokens (or eos_token) are in the chat_template
|
||||
- Checks if EOT tokens are encoded as multiple tokens in the tokenizer.
|
||||
- Checks for potential conflicts between train_on_eos and train_on_eot.
|
||||
"""
|
||||
if self.prompter.chat_template is None:
|
||||
# Usually this should not happen
|
||||
LOG.warning(
|
||||
"No chat template provided, skipping EOT and EOS token validation"
|
||||
)
|
||||
return
|
||||
|
||||
# If the EOT token is the same as the EOS token, we need to check differently
|
||||
if len(self.eot_tokens) == 1 and self.eot_tokens[0] == self.tokenizer.eos_token:
|
||||
# Check if the eos_token is in the chat_template or as a variable `eos_token`
|
||||
# Note: we check for `eos_token` in the string, but it could possibly not be a variable
|
||||
if (
|
||||
self.tokenizer.eos_token not in self.prompter.chat_template
|
||||
and "eos_token" not in self.prompter.chat_template
|
||||
):
|
||||
LOG.warning(
|
||||
f"EOS token '{self.tokenizer.eos_token}' not found in chat_template. Please check if your template/EOS token is correct."
|
||||
)
|
||||
return
|
||||
|
||||
# Create a new list to store tokens that should be kept
|
||||
valid_eot_tokens = []
|
||||
for token in self.eot_tokens:
|
||||
# Check if EOT token is in the chat_template
|
||||
if token not in self.prompter.chat_template:
|
||||
LOG.warning(f"EOT token '{token}' not found in chat_template.")
|
||||
# Don't add to the valid tokens list
|
||||
continue
|
||||
|
||||
valid_eot_tokens.append(token)
|
||||
|
||||
# Replace the original list with the filtered one
|
||||
self.eot_tokens = valid_eot_tokens
|
||||
|
||||
for token in self.eot_tokens:
|
||||
# If token in template, check if EOT token is in tokenizer and not encoded as multiple tokens
|
||||
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||
if not token_ids:
|
||||
raise ValueError(
|
||||
"EOT token encoding failed. Please check if the token is valid and can be encoded."
|
||||
)
|
||||
if token_ids and len(token_ids) > 1:
|
||||
raise ValueError(
|
||||
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config "
|
||||
"or (recommended) override unused added_tokens via `added_tokens_overrides: `."
|
||||
)
|
||||
|
||||
# If eos_token is in eot_tokens and conflict between train_on_eos and train_on_eot, raise an error
|
||||
if (
|
||||
self.tokenizer.eos_token in self.eot_tokens
|
||||
and self.train_on_eos != self.train_on_eot
|
||||
):
|
||||
raise ValueError(
|
||||
"Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot"
|
||||
f"train_on_eos: {self.train_on_eos}, train_on_eot: {self.train_on_eot}"
|
||||
f"eot_tokens: {self.eot_tokens}"
|
||||
f"eos_token: {self.tokenizer.eos_token}"
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
# Let calling code know we can handle lists of examples
|
||||
@@ -287,6 +364,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if (
|
||||
not self.roles_to_train
|
||||
and not self.train_on_eos
|
||||
and not self.train_on_eot
|
||||
and not self.prompter.message_field_training # type: ignore
|
||||
and not self.prompter.message_field_training_detail # type: ignore
|
||||
):
|
||||
@@ -322,6 +400,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
last_eos_idx = -1
|
||||
last_eot_idx = -1
|
||||
for index, turn in enumerate(turns):
|
||||
role = turn.get("role")
|
||||
content = turn.get("content")
|
||||
@@ -370,24 +449,45 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||
|
||||
# Handle EOS token
|
||||
eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
|
||||
if abs(eos_idx - turn_end_idx) <= 3: # Allow for some template padding
|
||||
last_eos_idx = eos_idx
|
||||
if self.train_on_eos == "all" or (
|
||||
self.train_on_eos == "turn" and should_train
|
||||
):
|
||||
labels[eos_idx] = input_ids[eos_idx]
|
||||
LOG.debug(f"EOS token set for training at index {eos_idx}")
|
||||
else:
|
||||
LOG.debug(
|
||||
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
|
||||
)
|
||||
# Handle special tokens (EOT and EOS)
|
||||
for token_type, find_func, train_option in [
|
||||
("EOT", self.find_first_eot_token, self.train_on_eot),
|
||||
("EOS", self.find_first_eos_token, self.train_on_eos),
|
||||
]:
|
||||
token_idx = find_func(input_ids, start_idx=turn_end_idx)
|
||||
|
||||
# Handle 'last' option for train_on_eos
|
||||
if self.train_on_eos == "last" and last_eos_idx != -1:
|
||||
labels[last_eos_idx] = input_ids[last_eos_idx]
|
||||
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
|
||||
if (
|
||||
token_idx != -1 and abs(token_idx - turn_end_idx) <= 3
|
||||
): # Allow for some template padding
|
||||
# Update the last token index
|
||||
if token_type == "EOT": # nosec B105
|
||||
last_eot_idx = token_idx
|
||||
else:
|
||||
last_eos_idx = token_idx
|
||||
|
||||
# Set labels if needed for this turn
|
||||
if train_option == "all" or (
|
||||
train_option == "turn" and should_train
|
||||
):
|
||||
labels[token_idx] = input_ids[token_idx]
|
||||
LOG.debug(
|
||||
f"{token_type} token set for training at index {token_idx}"
|
||||
)
|
||||
else:
|
||||
LOG.debug(
|
||||
f"{token_type} token missing after turn {turn}. {token_type.lower()}_idx: {token_idx}, turn_end_idx: {turn_end_idx}"
|
||||
)
|
||||
|
||||
# Handle 'last' option for special tokens
|
||||
for token_type, last_idx, train_option in [
|
||||
("EOT", last_eot_idx, self.train_on_eot),
|
||||
("EOS", last_eos_idx, self.train_on_eos),
|
||||
]:
|
||||
if train_option == "last" and last_idx != -1:
|
||||
labels[last_idx] = input_ids[last_idx]
|
||||
LOG.debug(
|
||||
f"Last {token_type} token set for training at index {last_idx}"
|
||||
)
|
||||
|
||||
LOG.debug(f"Final labels: {labels}")
|
||||
|
||||
@@ -404,6 +504,25 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_first_eot_token(self, input_ids, start_idx):
|
||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||
# Get token IDs for all EOT tokens
|
||||
eot_token_ids = []
|
||||
for token in self.eot_tokens:
|
||||
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||
if len(token_ids) != 1:
|
||||
raise ValueError(
|
||||
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
|
||||
)
|
||||
|
||||
eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple
|
||||
|
||||
# Search for any of the EOT token IDs
|
||||
for i in range(start_idx, len(input_ids)):
|
||||
if input_ids[i] in eot_token_ids:
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_turn(self, turns: list[dict], turn_idx: int):
|
||||
"""
|
||||
Locate the starting and ending indices of the specified turn in a conversation.
|
||||
@@ -568,6 +687,8 @@ class StrategyLoader:
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
"train_on_eot": ds_cfg.get("train_on_eot", None),
|
||||
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
|
||||
}
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -309,6 +309,7 @@ class AxolotlInputConfig(
|
||||
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
|
||||
) | None = None
|
||||
chat_template_jinja: str | None = None
|
||||
eot_tokens: list[str] | None = None
|
||||
default_system_message: str | None = None
|
||||
|
||||
fix_untrained_tokens: int | list[int] | None = None
|
||||
|
||||
Reference in New Issue
Block a user