tweak check for batched prompt data
@@ -21,6 +21,9 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
     Bradley-Terry reward model pairwise chat template prompt strategy.
     """
 
+    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
+        return all(isinstance(v, list) for v in prompt.values())
+
     def _tokenize_single_prompt(self, prompt):
         """
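
For illustration, a minimal sketch of the looser check the Bradley-Terry strategy keeps (the chosen/rejected column names here are hypothetical, not taken from the commit): when datasets maps with batched=True, every column arrives as a list of per-example values, so list-valued columns are enough to signal a batch.

    from typing import Any

    def is_prompt_batched(prompt: dict[str, Any]) -> bool:
        # With datasets.map(..., batched=True) every column holds a list
        # of per-example values, so list-valued columns signal a batch.
        return all(isinstance(v, list) for v in prompt.values())

    # Hypothetical pairwise reward-model columns.
    single = {"chosen": "response A", "rejected": "response B"}
    batch = {"chosen": ["response A"], "rejected": ["response B"]}

    assert not is_prompt_batched(single)
    assert is_prompt_batched(batch)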
@@ -226,12 +226,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         # Let calling code know we can handle lists of examples
         return True
 
+    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
+        return all(isinstance(v, list) for v in prompt.values()) and all(
+            isinstance(v, list) for v in prompt[self.messages]
+        )
+
     def tokenize_prompt(self, prompt: dict[str, Any]):
         """
         Public method that can handle either a single prompt or a batch of prompts.
         """
 
-        if not all(isinstance(v, list) for v in prompt.values()):
+        if not self.is_prompt_batched(prompt):
             return self._tokenize_single_prompt(prompt)
 
         res = defaultdict(lambda: [])
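
Why the base-class check got stricter: a single chat example's messages column is itself a list (of role/content dicts), so testing only prompt.values() misreads a single prompt as a batch. The tweak additionally requires every entry of prompt[self.messages] to be a list, i.e. a list of conversations. A minimal sketch, assuming self.messages resolves to a column named "messages":

    from typing import Any

    MESSAGES = "messages"  # assumed column name; the strategy reads it from config

    def is_prompt_batched(prompt: dict[str, Any]) -> bool:
        # Batched data has list-valued columns AND a messages column whose
        # entries are themselves lists (one conversation per entry).
        return all(isinstance(v, list) for v in prompt.values()) and all(
            isinstance(v, list) for v in prompt[MESSAGES]
        )

    # A single conversation: "messages" is a list of message dicts, so the
    # old values-only check would have wrongly treated it as batched.
    single = {"messages": [{"role": "user", "content": "hi"}]}

    # A batch of one conversation: each "messages" entry is itself a list.
    batched = {"messages": [[{"role": "user", "content": "hi"}]]}

    assert not is_prompt_batched(single)
    assert is_prompt_batched(batched)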