From 74d98ca6d87784553e6e8cc4d8bd5e771aeac100 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 7 Jan 2025 15:41:40 -0500 Subject: [PATCH] fix reward trainer calls for tokenization --- src/axolotl/prompt_strategies/bradley_terry/chat_template.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 0a02f54e5..627466287 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -39,7 +39,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): ) prompt[self.messages].append({"role": "user", "content": prompt["input"]}) prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) - chosen_tokenized = super().tokenize_prompt(prompt) + chosen_tokenized = super()._tokenize_single_prompt(prompt) if len(chosen_tokenized["input_ids"]) > max_length: LOG.warning( @@ -62,7 +62,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): prompt[self.messages].append( {"role": "assistant", "content": prompt["rejected"]} ) - rejected_tokenized = super().tokenize_prompt(prompt) + rejected_tokenized = super()._tokenize_single_prompt(prompt) if len(rejected_tokenized["input_ids"]) > max_length: LOG.warning(