tweak check for batched prompt data

Wing Lian
2025-01-07 14:54:32 -05:00
parent 93903f4aa5
commit 2045ff2b7a
2 changed files with 9 additions and 1 deletion


@@ -21,6 +21,9 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
     Bradley-Terry reward model pairwise chat template prompt strategy.
     """
 
+    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
+        return all(isinstance(v, list) for v in prompt.values())
+
     def _tokenize_single_prompt(self, prompt):
         """


@@ -226,12 +226,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         # Let calling code know we can handle lists of examples
         return True
 
+    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
+        return all(isinstance(v, list) for v in prompt.values()) and all(
+            isinstance(v, list) for v in prompt[self.messages]
+        )
+
     def tokenize_prompt(self, prompt: dict[str, Any]):
         """
         Public method that can handle either a single prompt or a batch of prompts.
         """
-        if not all(isinstance(v, list) for v in prompt.values()):
+        if not self.is_prompt_batched(prompt):
             return self._tokenize_single_prompt(prompt)
 
         res = defaultdict(lambda: [])
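
To make the tweaked condition concrete, a small self-contained sketch, assuming the messages column is named "messages" (what self.messages refers to in the strategy): for a single conversation that column is already a list of message dicts, so the old top-level check could misread it as a batch; the added per-element check is what tells the two shapes apart.

from typing import Any

def is_prompt_batched(prompt: dict[str, Any], messages_key: str = "messages") -> bool:
    # Mirrors the method above; messages_key stands in for self.messages.
    return all(isinstance(v, list) for v in prompt.values()) and all(
        isinstance(v, list) for v in prompt[messages_key]
    )

# One conversation: "messages" is a list of dicts, so the top-level check
# alone would wrongly flag it as batched.
single = {"messages": [{"role": "user", "content": "hi"}]}

# A batch of two conversations: each element of "messages" is itself a list.
batched = {"messages": [
    [{"role": "user", "content": "hi"}],
    [{"role": "user", "content": "hello"}],
]}

assert not is_prompt_batched(single)
assert is_prompt_batched(batched)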