Pre-cache the EOT token ids once at init rather than re-encoding them on each iteration (#3594) [skip ci]
This commit is contained in:
@@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
self._validate_eot_and_eos_tokens()
|
||||
|
||||
# Pre-cache EOT token IDs to avoid re-encoding on every call
|
||||
self._eot_token_ids = set()
|
||||
for token in self.eot_tokens:
|
||||
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
|
||||
if len(token_ids) == 1:
|
||||
self._eot_token_ids.add(token_ids[0])
|
||||
|
||||
def _validate_eot_and_eos_tokens(self):
|
||||
"""
|
||||
- Validates that EOT tokens (or eos_token) are in the chat_template
|
||||
@@ -632,20 +639,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def find_first_eot_token(self, input_ids, start_idx):
    """Return the index of the first EOT token in ``input_ids`` at or after ``start_idx``.

    Scans ``input_ids`` against ``self._eot_token_ids`` — a set of
    single-token IDs pre-computed once (in ``__init__``) from
    ``self.eot_tokens``. The previous implementation re-encoded every EOT
    token string with ``tokenizer.encode`` on every call; using the cached
    set removes that per-call cost and makes each membership test O(1).

    Args:
        input_ids: Sequence of token IDs to scan.
        start_idx: Index at which to start scanning (inclusive).

    Returns:
        Index of the first matching EOT token, or -1 if none is found
        (including when ``start_idx`` is past the end of ``input_ids``).
    """
    # Use pre-cached EOT token IDs (computed once in __init__); set
    # membership is O(1) per position.
    for i in range(start_idx, len(input_ids)):
        if input_ids[i] in self._eot_token_ids:
            return i
    return -1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user