From 2045ff2b7aef43b9f39ce79688b1066766165ebe Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 7 Jan 2025 14:54:32 -0500 Subject: [PATCH] tweak check for batched prompt data --- .../prompt_strategies/bradley_terry/chat_template.py | 3 +++ src/axolotl/prompt_strategies/chat_template.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 0a02f54e5..0bbb1de8b 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -21,6 +21,9 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): Bradley-Terry reward model pairwise chat template prompt strategy. """ + def is_prompt_batched(self, prompt: dict[str, Any]) -> bool: + return all(isinstance(v, list) for v in prompt.values()) + def _tokenize_single_prompt(self, prompt): """ diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index f77dd30d9..051460427 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -226,12 +226,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): # Let calling code know we can handle lists of examples return True + def is_prompt_batched(self, prompt: dict[str, Any]) -> bool: + return all(isinstance(v, list) for v in prompt.values()) and all( + isinstance(v, list) for v in prompt[self.messages] + ) + def tokenize_prompt(self, prompt: dict[str, Any]): """ Public method that can handle either a single prompt or a batch of prompts. """ - if not all(isinstance(v, list) for v in prompt.values()): + if not self.is_prompt_batched(prompt): return self._tokenize_single_prompt(prompt) res = defaultdict(lambda: [])