tweak check for batched prompt data
This commit is contained in:
@@ -21,6 +21,9 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
Bradley-Terry reward model pairwise chat template prompt strategy.
|
Bradley-Terry reward model pairwise chat template prompt strategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||||
|
return all(isinstance(v, list) for v in prompt.values())
|
||||||
|
|
||||||
def _tokenize_single_prompt(self, prompt):
|
def _tokenize_single_prompt(self, prompt):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -226,12 +226,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
# Let calling code know we can handle lists of examples
|
# Let calling code know we can handle lists of examples
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||||
|
return all(isinstance(v, list) for v in prompt.values()) and all(
|
||||||
|
isinstance(v, list) for v in prompt[self.messages]
|
||||||
|
)
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt: dict[str, Any]):
|
def tokenize_prompt(self, prompt: dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
Public method that can handle either a single prompt or a batch of prompts.
|
Public method that can handle either a single prompt or a batch of prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not all(isinstance(v, list) for v in prompt.values()):
|
if not self.is_prompt_batched(prompt):
|
||||||
return self._tokenize_single_prompt(prompt)
|
return self._tokenize_single_prompt(prompt)
|
||||||
|
|
||||||
res = defaultdict(lambda: [])
|
res = defaultdict(lambda: [])
|
||||||
|
|||||||
Reference in New Issue
Block a user