diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
index 0a02f54e5..0bbb1de8b 100644
--- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
+++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -21,6 +21,9 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
     Bradley-Terry reward model pairwise chat template prompt strategy.
     """
 
+    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
+        return all(isinstance(v, list) for v in prompt.values())
+
     def _tokenize_single_prompt(self, prompt):
         """
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index f77dd30d9..051460427 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -226,12 +226,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         # Let calling code know we can handle lists of examples
         return True
 
+    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
+        return all(isinstance(v, list) for v in prompt.values()) and all(
+            isinstance(v, list) for v in prompt[self.messages]
+        )
+
     def tokenize_prompt(self, prompt: dict[str, Any]):
         """
         Public method that can handle either a single prompt or a batch of prompts.
         """
-        if not all(isinstance(v, list) for v in prompt.values()):
+        if not self.is_prompt_batched(prompt):
             return self._tokenize_single_prompt(prompt)
 
         res = defaultdict(lambda: [])