Improve the batched-prompt check to tolerate a missing messages key

This commit is contained in:
Wing Lian
2025-01-07 16:57:47 -05:00
parent 74d98ca6d8
commit 27bb21c459

View File

@@ -227,9 +227,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return True
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
    """Return ``True`` if *prompt* looks like a batched example.

    A batched prompt maps every feature name to a list, and the
    messages feature (keyed by ``self.messages``) is itself a list of
    lists — one message list per example in the batch.

    Returns ``False`` when the messages key is absent, rather than
    raising ``KeyError`` — an un-batched or malformed prompt is simply
    "not batched".
    """
    try:
        return all(isinstance(v, list) for v in prompt.values()) and all(
            isinstance(v, list) for v in prompt[self.messages]
        )
    except KeyError:
        # Missing messages key: treat as not batched instead of crashing.
        return False
def tokenize_prompt(self, prompt: dict[str, Any]):
"""