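Fix reward trainer tokenization: BTChatTemplateStrategy now calls `_tokenize_single_prompt` instead of `tokenize_prompt` for the chosen and rejected prompts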

This commit is contained in:
Wing Lian
2025-01-07 15:41:40 -05:00
parent ec4dfb02c8
commit 74d98ca6d8

View File

@@ -39,7 +39,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
) )
prompt[self.messages].append({"role": "user", "content": prompt["input"]}) prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt) chosen_tokenized = super()._tokenize_single_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length: if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning( LOG.warning(
@@ -62,7 +62,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
prompt[self.messages].append( prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]} {"role": "assistant", "content": prompt["rejected"]}
) )
rejected_tokenized = super().tokenize_prompt(prompt) rejected_tokenized = super()._tokenize_single_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length: if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning( LOG.warning(