pre-cache the eot token ids rather than re-encoding them on each iteration (#3594) [skip ci]

This commit is contained in:
Wing Lian
2026-04-11 20:05:21 -04:00
committed by GitHub
parent e77a185e86
commit 122b50bad6

View File

@@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._validate_eot_and_eos_tokens()
# Pre-cache EOT token IDs to avoid re-encoding on every call
self._eot_token_ids = set()
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) == 1:
self._eot_token_ids.add(token_ids[0])
def _validate_eot_and_eos_tokens(self):
"""
- Validates that EOT tokens (or eos_token) are in the chat_template
@@ -632,20 +639,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def find_first_eot_token(self, input_ids, start_idx):
    """Find the first EOT token in ``input_ids`` starting from ``start_idx``.

    Scans ``input_ids[start_idx:]`` for any token ID in the pre-cached set
    ``self._eot_token_ids`` (computed once in ``__init__`` to avoid
    re-encoding the EOT token strings on every call).

    :param input_ids: sequence of token IDs to search.
    :param start_idx: index at which to begin the search (inclusive).
    :return: index of the first matching EOT token, or -1 if none is found.
    """
    # Hoist the attribute lookup out of the loop; membership tests against
    # the set are O(1) per token.
    eot_ids = self._eot_token_ids
    for i in range(start_idx, len(input_ids)):
        if input_ids[i] in eot_ids:
            return i
    return -1