diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index a7f749f3b..a943a1448 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): self._validate_eot_and_eos_tokens() + # Pre-cache EOT token IDs to avoid re-encoding on every call + self._eot_token_ids = set() + for token in self.eot_tokens: + token_ids = self.tokenizer.encode(token, add_special_tokens=False) + if len(token_ids) == 1: + self._eot_token_ids.add(token_ids[0]) + def _validate_eot_and_eos_tokens(self): """ - Validates that EOT tokens (or eos_token) are in the chat_template @@ -632,20 +639,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def find_first_eot_token(self, input_ids, start_idx): """Find the first EOT token in the input_ids starting from start_idx.""" - # Get token IDs for all EOT tokens - eot_token_ids = [] - for token in self.eot_tokens: - token_ids = self.tokenizer.encode(token, add_special_tokens=False) - if len(token_ids) != 1: - raise ValueError( - f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config." - ) - - eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple - - # Search for any of the EOT token IDs + # Use pre-cached EOT token IDs (computed once in __init__) for i in range(start_idx, len(input_ids)): - if input_ids[i] in eot_token_ids: + if input_ids[i] in self._eot_token_ids: return i return -1