Handle other reasoning trace dataset formats (#2591)

* Handle other reasoning trace dataset formats

* rename var to improve readability

* chore: refactor with comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2025-04-30 03:32:55 -04:00
committed by GitHub
parent 2413688b08
commit baeb00231b
2 changed files with 98 additions and 42 deletions

View File

@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = { message_property_mappings = {
"role": "role", "role": "role",
"content": "content", "content": "content",
"reasoning_content": "reasoning_content",
} }
if roles: if roles:
@@ -661,16 +662,46 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
# if the role is assistant that we want to use reasoning_content # if the role is assistant that we want to use reasoning_content
if self.split_thinking and transformed_message["role"] == "assistant": if self.split_thinking and transformed_message["role"] == "assistant":
content = transformed_message["content"] content = transformed_message["content"]
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")] thinking_pairs = [
for pair in pairs: ("<think>", "</think>"),
if pair[0] in content and pair[1] in content: ("<reasoning>", "</reasoning>"),
start_idx = content.find(pair[0]) ("<|begin_of_thought|>", "<|end_of_thought|>"),
end_idx = content.find(pair[1]) ]
thinking_content = content[start_idx + len(pair[0]) : end_idx] content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")]
for tpair in thinking_pairs:
# check if the thinking pair is in the content
if tpair[0] in content and tpair[1] in content:
# find the start and end index of the thinking pair
t_start_idx = content.find(tpair[0])
t_end_idx = content.find(tpair[1])
# get the thinking content
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message["reasoning_content"] = thinking_content.strip() transformed_message["reasoning_content"] = thinking_content.strip()
transformed_message["content"] = content[
end_idx + len(pair[1]) : # take remainder of the content
].lstrip() # strip whitespace from beginning of the remainder (thinking tokens)
remainder = content[t_end_idx + len(tpair[1]) :].lstrip()
# check if the content pair is in the remainder
cpair_found = False
for cpair in content_pairs:
if cpair[0] in remainder and cpair[1] in remainder:
# find the start and end index of the content pair
c_start_idx = remainder.find(cpair[0])
c_end_idx = remainder.find(cpair[1])
# get the content content
content_content = remainder[
c_start_idx + len(cpair[0]) : c_end_idx
]
transformed_message["content"] = content_content.strip()
cpair_found = True
break
# else, the content is the remainder
if not cpair_found:
transformed_message["content"] = remainder
break break
# Determine which keys in the original message were not mapped # Determine which keys in the original message were not mapped

View File

@@ -34,7 +34,31 @@ def messages_w_reasoning_fixture():
"content": "<think>lorem</think>\nwelcome", "content": "<think>lorem</think>\nwelcome",
}, },
] ]
} },
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "<|begin_of_thought|>lorem<|end_of_thought|>\n<|begin_of_solution|>welcome\n<|end_of_solution|>",
},
]
},
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "<reasoning>lorem</reasoning>\nwelcome",
},
]
},
] ]
) )
@@ -83,36 +107,37 @@ class TestSplitThinking:
} }
), ),
) )
transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0]) for conversation in messages_w_reasoning:
assert transformed_prompt[0]["role"] == "user" transformed_prompt = strategy.get_conversation_thread(conversation)
assert transformed_prompt[1]["role"] == "assistant" assert transformed_prompt[0]["role"] == "user"
assert transformed_prompt[1]["reasoning_content"] == "lorem" assert transformed_prompt[1]["role"] == "assistant"
assert transformed_prompt[1]["content"] == "welcome" assert transformed_prompt[1]["reasoning_content"] == "lorem"
assert transformed_prompt[1]["content"] == "welcome"
res = strategy.tokenize_prompt(messages_w_reasoning[0]) res = strategy.tokenize_prompt(conversation)
input_ids = res["input_ids"] input_ids = res["input_ids"]
# fmt: off # fmt: off
expected_input_ids = [ expected_input_ids = [
151644, # im_start 151644, # im_start
872, # user 872, # user
198, # \n 198, # \n
14990, # hello 14990, # hello
151645, # im_end 151645, # im_end
198, # \n 198, # \n
151644, # im_start 151644, # im_start
77091, # assistant 77091, # assistant
198, # \n 198, # \n
151667, # think 151667, # think
198, # \n 198, # \n
385, 1826, # lorem 385, 1826, # lorem
198, # \n 198, # \n
151668, # /think 151668, # /think
271, # \n 271, # \n
34084, # welcome 34084, # welcome
151645, # im_end 151645, # im_end
198, # \n 198, # \n
] ]
# fmt: on # fmt: on
assert ( assert (
input_ids == expected_input_ids input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"