diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index c2948fc11..638cee559 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = {
"role": "role",
"content": "content",
+ "reasoning_content": "reasoning_content",
}
if roles:
@@ -661,16 +662,46 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
# if the role is assistant that we want to use reasoning_content
if self.split_thinking and transformed_message["role"] == "assistant":
content = transformed_message["content"]
- pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
- for pair in pairs:
- if pair[0] in content and pair[1] in content:
- start_idx = content.find(pair[0])
- end_idx = content.find(pair[1])
- thinking_content = content[start_idx + len(pair[0]) : end_idx]
+ thinking_pairs = [
+ ("<think>", "</think>"),
+ ("<reasoning>", "</reasoning>"),
+ ("<|begin_of_thought|>", "<|end_of_thought|>"),
+ ]
+ content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")]
+ for tpair in thinking_pairs:
+ # check if the thinking pair is in the content
+ if tpair[0] in content and tpair[1] in content:
+ # find the start and end index of the thinking pair
+ t_start_idx = content.find(tpair[0])
+ t_end_idx = content.find(tpair[1])
+
+ # get the thinking content
+ thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message["reasoning_content"] = thinking_content.strip()
- transformed_message["content"] = content[
- end_idx + len(pair[1]) :
- ].lstrip()
+
+ # take remainder of the content
+ # strip whitespace from beginning of the remainder (thinking tokens)
+ remainder = content[t_end_idx + len(tpair[1]) :].lstrip()
+
+ # check if the content pair is in the remainder
+ cpair_found = False
+ for cpair in content_pairs:
+ if cpair[0] in remainder and cpair[1] in remainder:
+ # find the start and end index of the content pair
+ c_start_idx = remainder.find(cpair[0])
+ c_end_idx = remainder.find(cpair[1])
+
+ # get the solution content enclosed by the content pair
+ content_content = remainder[
+ c_start_idx + len(cpair[0]) : c_end_idx
+ ]
+ transformed_message["content"] = content_content.strip()
+ cpair_found = True
+ break
+
+ # else, the content is the remainder
+ if not cpair_found:
+ transformed_message["content"] = remainder
break
# Determine which keys in the original message were not mapped
diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py
index 236a7ec39..9fe292317 100644
--- a/tests/prompt_strategies/test_chat_templates_thinking.py
+++ b/tests/prompt_strategies/test_chat_templates_thinking.py
@@ -34,7 +34,31 @@ def messages_w_reasoning_fixture():
"content": "<think>lorem</think>\nwelcome",
},
]
- }
+ },
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello",
+ },
+ {
+ "role": "assistant",
+ "content": "<|begin_of_thought|>lorem<|end_of_thought|>\n<|begin_of_solution|>welcome\n<|end_of_solution|>",
+ },
+ ]
+ },
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello",
+ },
+ {
+ "role": "assistant",
+                    "content": "<reasoning>lorem</reasoning>\nwelcome",
+ },
+ ]
+ },
]
)
@@ -83,36 +107,37 @@ class TestSplitThinking:
}
),
)
- transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
- assert transformed_prompt[0]["role"] == "user"
- assert transformed_prompt[1]["role"] == "assistant"
- assert transformed_prompt[1]["reasoning_content"] == "lorem"
- assert transformed_prompt[1]["content"] == "welcome"
+ for conversation in messages_w_reasoning:
+ transformed_prompt = strategy.get_conversation_thread(conversation)
+ assert transformed_prompt[0]["role"] == "user"
+ assert transformed_prompt[1]["role"] == "assistant"
+ assert transformed_prompt[1]["reasoning_content"] == "lorem"
+ assert transformed_prompt[1]["content"] == "welcome"
- res = strategy.tokenize_prompt(messages_w_reasoning[0])
- input_ids = res["input_ids"]
- # fmt: off
- expected_input_ids = [
- 151644, # im_start
- 872, # user
- 198, # \n
- 14990, # hello
- 151645, # im_end
- 198, # \n
- 151644, # im_start
- 77091, # assistant
- 198, # \n
- 151667, # think
- 198, # \n
- 385, 1826, # lorem
- 198, # \n
- 151668, # /think
- 271, # \n
- 34084, # welcome
- 151645, # im_end
- 198, # \n
- ]
- # fmt: on
- assert (
- input_ids == expected_input_ids
- ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+ res = strategy.tokenize_prompt(conversation)
+ input_ids = res["input_ids"]
+ # fmt: off
+ expected_input_ids = [
+ 151644, # im_start
+ 872, # user
+ 198, # \n
+ 14990, # hello
+ 151645, # im_end
+ 198, # \n
+ 151644, # im_start
+ 77091, # assistant
+ 198, # \n
+ 151667, # think
+ 198, # \n
+ 385, 1826, # lorem
+ 198, # \n
+ 151668, # /think
+ 271, # \n
+ 34084, # welcome
+ 151645, # im_end
+ 198, # \n
+ ]
+ # fmt: on
+ assert (
+ input_ids == expected_input_ids
+ ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"