Handle other reasoning trace dataset formats (#2591)
Squashed changes: (1) handle other reasoning trace dataset formats; (2) rename var to improve readability; (3) chore: refactor with comments. Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -34,7 +34,31 @@ def messages_w_reasoning_fixture():
|
||||
"content": "<think>lorem</think>\nwelcome",
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<|begin_of_thought|>lorem<|end_of_thought|>\n<|begin_of_solution|>welcome\n<|end_of_solution|>",
|
||||
},
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<reasoning>lorem</reasoning>\nwelcome",
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
@@ -83,36 +107,37 @@ class TestSplitThinking:
|
||||
}
|
||||
),
|
||||
)
|
||||
transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
|
||||
assert transformed_prompt[0]["role"] == "user"
|
||||
assert transformed_prompt[1]["role"] == "assistant"
|
||||
assert transformed_prompt[1]["reasoning_content"] == "lorem"
|
||||
assert transformed_prompt[1]["content"] == "welcome"
|
||||
for conversation in messages_w_reasoning:
|
||||
transformed_prompt = strategy.get_conversation_thread(conversation)
|
||||
assert transformed_prompt[0]["role"] == "user"
|
||||
assert transformed_prompt[1]["role"] == "assistant"
|
||||
assert transformed_prompt[1]["reasoning_content"] == "lorem"
|
||||
assert transformed_prompt[1]["content"] == "welcome"
|
||||
|
||||
res = strategy.tokenize_prompt(messages_w_reasoning[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
expected_input_ids = [
|
||||
151644, # im_start
|
||||
872, # user
|
||||
198, # \n
|
||||
14990, # hello
|
||||
151645, # im_end
|
||||
198, # \n
|
||||
151644, # im_start
|
||||
77091, # assistant
|
||||
198, # \n
|
||||
151667, # think
|
||||
198, # \n
|
||||
385, 1826, # lorem
|
||||
198, # \n
|
||||
151668, # /think
|
||||
271, # \n\n
|
||||
34084, # welcome
|
||||
151645, # im_end
|
||||
198, # \n
|
||||
]
|
||||
# fmt: on
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
res = strategy.tokenize_prompt(conversation)
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
expected_input_ids = [
|
||||
151644, # im_start
|
||||
872, # user
|
||||
198, # \n
|
||||
14990, # hello
|
||||
151645, # im_end
|
||||
198, # \n
|
||||
151644, # im_start
|
||||
77091, # assistant
|
||||
198, # \n
|
||||
151667, # think
|
||||
198, # \n
|
||||
385, 1826, # lorem
|
||||
198, # \n
|
||||
151668, # /think
|
||||
271, # \n\n
|
||||
34084, # welcome
|
||||
151645, # im_end
|
||||
198, # \n
|
||||
]
|
||||
# fmt: on
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
|
||||
Reference in New Issue
Block a user