diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index c2948fc11..638cee559 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter): message_property_mappings = { "role": "role", "content": "content", + "reasoning_content": "reasoning_content", } if roles: @@ -661,16 +662,46 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): # if the role is assistant that we want to use reasoning_content if self.split_thinking and transformed_message["role"] == "assistant": content = transformed_message["content"] - pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")] - for pair in pairs: - if pair[0] in content and pair[1] in content: - start_idx = content.find(pair[0]) - end_idx = content.find(pair[1]) - thinking_content = content[start_idx + len(pair[0]) : end_idx] + thinking_pairs = [ + ("<think>", "</think>"), + ("<reasoning>", "</reasoning>"), + ("<|begin_of_thought|>", "<|end_of_thought|>"), + ] + content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")] + for tpair in thinking_pairs: + # check if the thinking pair is in the content + if tpair[0] in content and tpair[1] in content: + # find the start and end index of the thinking pair + t_start_idx = content.find(tpair[0]) + t_end_idx = content.find(tpair[1]) + + # get the thinking content + thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx] transformed_message["reasoning_content"] = thinking_content.strip() - transformed_message["content"] = content[ - end_idx + len(pair[1]) : - ].lstrip() + + # take remainder of the content + # strip whitespace from beginning of the remainder (thinking tokens) + remainder = content[t_end_idx + len(tpair[1]) :].lstrip() + + # check if the content pair is in the remainder + cpair_found = False + for cpair in content_pairs: + if cpair[0] in remainder and cpair[1] in remainder: + # find the start and end index of the content pair + c_start_idx = 
remainder.find(cpair[0]) + c_end_idx = remainder.find(cpair[1]) + + # get the content content + content_content = remainder[ + c_start_idx + len(cpair[0]) : c_end_idx + ] + transformed_message["content"] = content_content.strip() + cpair_found = True + break + + # else, the content is the remainder + if not cpair_found: + transformed_message["content"] = remainder break # Determine which keys in the original message were not mapped diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py index 236a7ec39..9fe292317 100644 --- a/tests/prompt_strategies/test_chat_templates_thinking.py +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -34,7 +34,31 @@ def messages_w_reasoning_fixture(): "content": "<think>lorem</think>\nwelcome", }, ] - } + }, + { + "messages": [ + { + "role": "user", + "content": "hello", + }, + { + "role": "assistant", + "content": "<|begin_of_thought|>lorem<|end_of_thought|>\n<|begin_of_solution|>welcome\n<|end_of_solution|>", + }, + ] + }, + { + "messages": [ + { + "role": "user", + "content": "hello", + }, + { + "role": "assistant", + "content": "<reasoning>lorem</reasoning>\nwelcome", + }, + ] + }, ] ) @@ -83,36 +107,37 @@ class TestSplitThinking: } ), ) - transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0]) - assert transformed_prompt[0]["role"] == "user" - assert transformed_prompt[1]["role"] == "assistant" - assert transformed_prompt[1]["reasoning_content"] == "lorem" - assert transformed_prompt[1]["content"] == "welcome" + for conversation in messages_w_reasoning: + transformed_prompt = strategy.get_conversation_thread(conversation) + assert transformed_prompt[0]["role"] == "user" + assert transformed_prompt[1]["role"] == "assistant" + assert transformed_prompt[1]["reasoning_content"] == "lorem" + assert transformed_prompt[1]["content"] == "welcome" - res = strategy.tokenize_prompt(messages_w_reasoning[0]) - input_ids = res["input_ids"] - # fmt: off - expected_input_ids = 
[ - 151644, # im_start - 872, # user - 198, # \n - 14990, # hello - 151645, # im_end - 198, # \n - 151644, # im_start - 77091, # assistant - 198, # \n - 151667, # think - 198, # \n - 385, 1826, # lorem - 198, # \n - 151668, # /think - 271, # \n - 34084, # welcome - 151645, # im_end - 198, # \n - ] - # fmt: on - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + res = strategy.tokenize_prompt(conversation) + input_ids = res["input_ids"] + # fmt: off + expected_input_ids = [ + 151644, # im_start + 872, # user + 198, # \n + 14990, # hello + 151645, # im_end + 198, # \n + 151644, # im_start + 77091, # assistant + 198, # \n + 151667, # think + 198, # \n + 385, 1826, # lorem + 198, # \n + 151668, # /think + 271, # \n + 34084, # welcome + 151645, # im_end + 198, # \n + ] + # fmt: on + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"