Handle other reasoning trace dataset formats (#2591)
* Handle other reasoning trace dataset formats * rename var to improve readability * chore: refactor with comments --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
message_property_mappings = {
|
message_property_mappings = {
|
||||||
"role": "role",
|
"role": "role",
|
||||||
"content": "content",
|
"content": "content",
|
||||||
|
"reasoning_content": "reasoning_content",
|
||||||
}
|
}
|
||||||
|
|
||||||
if roles:
|
if roles:
|
||||||
@@ -661,16 +662,46 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
# if the role is assistant that we want to use reasoning_content
|
# if the role is assistant that we want to use reasoning_content
|
||||||
if self.split_thinking and transformed_message["role"] == "assistant":
|
if self.split_thinking and transformed_message["role"] == "assistant":
|
||||||
content = transformed_message["content"]
|
content = transformed_message["content"]
|
||||||
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
|
thinking_pairs = [
|
||||||
for pair in pairs:
|
("<think>", "</think>"),
|
||||||
if pair[0] in content and pair[1] in content:
|
("<reasoning>", "</reasoning>"),
|
||||||
start_idx = content.find(pair[0])
|
("<|begin_of_thought|>", "<|end_of_thought|>"),
|
||||||
end_idx = content.find(pair[1])
|
]
|
||||||
thinking_content = content[start_idx + len(pair[0]) : end_idx]
|
content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")]
|
||||||
|
for tpair in thinking_pairs:
|
||||||
|
# check if the thinking pair is in the content
|
||||||
|
if tpair[0] in content and tpair[1] in content:
|
||||||
|
# find the start and end index of the thinking pair
|
||||||
|
t_start_idx = content.find(tpair[0])
|
||||||
|
t_end_idx = content.find(tpair[1])
|
||||||
|
|
||||||
|
# get the thinking content
|
||||||
|
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
||||||
transformed_message["reasoning_content"] = thinking_content.strip()
|
transformed_message["reasoning_content"] = thinking_content.strip()
|
||||||
transformed_message["content"] = content[
|
|
||||||
end_idx + len(pair[1]) :
|
# take remainder of the content
|
||||||
].lstrip()
|
# strip whitespace from beginning of the remainder (thinking tokens)
|
||||||
|
remainder = content[t_end_idx + len(tpair[1]) :].lstrip()
|
||||||
|
|
||||||
|
# check if the content pair is in the remainder
|
||||||
|
cpair_found = False
|
||||||
|
for cpair in content_pairs:
|
||||||
|
if cpair[0] in remainder and cpair[1] in remainder:
|
||||||
|
# find the start and end index of the content pair
|
||||||
|
c_start_idx = remainder.find(cpair[0])
|
||||||
|
c_end_idx = remainder.find(cpair[1])
|
||||||
|
|
||||||
|
# get the content content
|
||||||
|
content_content = remainder[
|
||||||
|
c_start_idx + len(cpair[0]) : c_end_idx
|
||||||
|
]
|
||||||
|
transformed_message["content"] = content_content.strip()
|
||||||
|
cpair_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# else, the content is the remainder
|
||||||
|
if not cpair_found:
|
||||||
|
transformed_message["content"] = remainder
|
||||||
break
|
break
|
||||||
|
|
||||||
# Determine which keys in the original message were not mapped
|
# Determine which keys in the original message were not mapped
|
||||||
|
|||||||
@@ -34,7 +34,31 @@ def messages_w_reasoning_fixture():
|
|||||||
"content": "<think>lorem</think>\nwelcome",
|
"content": "<think>lorem</think>\nwelcome",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "<|begin_of_thought|>lorem<|end_of_thought|>\n<|begin_of_solution|>welcome\n<|end_of_solution|>",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "<reasoning>lorem</reasoning>\nwelcome",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,36 +107,37 @@ class TestSplitThinking:
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
|
for conversation in messages_w_reasoning:
|
||||||
assert transformed_prompt[0]["role"] == "user"
|
transformed_prompt = strategy.get_conversation_thread(conversation)
|
||||||
assert transformed_prompt[1]["role"] == "assistant"
|
assert transformed_prompt[0]["role"] == "user"
|
||||||
assert transformed_prompt[1]["reasoning_content"] == "lorem"
|
assert transformed_prompt[1]["role"] == "assistant"
|
||||||
assert transformed_prompt[1]["content"] == "welcome"
|
assert transformed_prompt[1]["reasoning_content"] == "lorem"
|
||||||
|
assert transformed_prompt[1]["content"] == "welcome"
|
||||||
|
|
||||||
res = strategy.tokenize_prompt(messages_w_reasoning[0])
|
res = strategy.tokenize_prompt(conversation)
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
# fmt: off
|
# fmt: off
|
||||||
expected_input_ids = [
|
expected_input_ids = [
|
||||||
151644, # im_start
|
151644, # im_start
|
||||||
872, # user
|
872, # user
|
||||||
198, # \n
|
198, # \n
|
||||||
14990, # hello
|
14990, # hello
|
||||||
151645, # im_end
|
151645, # im_end
|
||||||
198, # \n
|
198, # \n
|
||||||
151644, # im_start
|
151644, # im_start
|
||||||
77091, # assistant
|
77091, # assistant
|
||||||
198, # \n
|
198, # \n
|
||||||
151667, # think
|
151667, # think
|
||||||
198, # \n
|
198, # \n
|
||||||
385, 1826, # lorem
|
385, 1826, # lorem
|
||||||
198, # \n
|
198, # \n
|
||||||
151668, # /think
|
151668, # /think
|
||||||
271, # \n
|
271, # \n
|
||||||
34084, # welcome
|
34084, # welcome
|
||||||
151645, # im_end
|
151645, # im_end
|
||||||
198, # \n
|
198, # \n
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
assert (
|
assert (
|
||||||
input_ids == expected_input_ids
|
input_ids == expected_input_ids
|
||||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
|
|||||||
Reference in New Issue
Block a user