automatically split out reasoning trace from dataset (#2579)
* automatically split out reasoning trace from dataset * chore: lint * fix import
This commit is contained in:
@@ -228,6 +228,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
train_on_eos: Optional[str] = None,
|
train_on_eos: Optional[str] = None,
|
||||||
train_on_eot: Optional[str] = None,
|
train_on_eot: Optional[str] = None,
|
||||||
eot_tokens: Optional[List[str]] = None,
|
eot_tokens: Optional[List[str]] = None,
|
||||||
|
split_thinking: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||||
self.prompter: ChatTemplatePrompter = prompter
|
self.prompter: ChatTemplatePrompter = prompter
|
||||||
@@ -247,6 +248,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
self.eot_tokens = (
|
self.eot_tokens = (
|
||||||
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||||
)
|
)
|
||||||
|
self.split_thinking = split_thinking
|
||||||
|
|
||||||
self.images = "images"
|
self.images = "images"
|
||||||
|
|
||||||
@@ -655,6 +657,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
transformed_message["role"], transformed_message["role"]
|
transformed_message["role"], transformed_message["role"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO handle reasoning_content with split_thinking
|
||||||
|
# if the role is assistant, split any inline reasoning out into reasoning_content
|
||||||
|
if self.split_thinking and transformed_message["role"] == "assistant":
|
||||||
|
content = transformed_message["content"]
|
||||||
|
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
|
||||||
|
for pair in pairs:
|
||||||
|
if pair[0] in content and pair[1] in content:
|
||||||
|
start_idx = content.find(pair[0])
|
||||||
|
end_idx = content.find(pair[1])
|
||||||
|
thinking_content = content[start_idx + len(pair[0]) : end_idx]
|
||||||
|
transformed_message["reasoning_content"] = thinking_content.strip()
|
||||||
|
transformed_message["content"] = content[
|
||||||
|
end_idx + len(pair[1]) :
|
||||||
|
].lstrip()
|
||||||
|
break
|
||||||
|
|
||||||
# Determine which keys in the original message were not mapped
|
# Determine which keys in the original message were not mapped
|
||||||
mapped_values = set(self.prompter.message_property_mappings.values())
|
mapped_values = set(self.prompter.message_property_mappings.values())
|
||||||
remaining_keys = set(message) - mapped_values
|
remaining_keys = set(message) - mapped_values
|
||||||
@@ -689,6 +707,7 @@ class StrategyLoader:
|
|||||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||||
"train_on_eot": ds_cfg.get("train_on_eot", None),
|
"train_on_eot": ds_cfg.get("train_on_eot", None),
|
||||||
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
|
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
|
||||||
|
"split_thinking": ds_cfg.get("split_thinking", False),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class SFTDataset(BaseModel):
|
|||||||
message_property_mappings: dict[str, str] | None = None
|
message_property_mappings: dict[str, str] | None = None
|
||||||
message_field_training: str | None = None
|
message_field_training: str | None = None
|
||||||
message_field_training_detail: str | None = None
|
message_field_training_detail: str | None = None
|
||||||
|
split_thinking: bool | None = None
|
||||||
logprobs_field: str | None = None
|
logprobs_field: str | None = None
|
||||||
temperature: float | None = None
|
temperature: float | None = None
|
||||||
roles_to_train: list[str] | None = None
|
roles_to_train: list[str] | None = None
|
||||||
|
|||||||
@@ -90,6 +90,12 @@ def download_qwen_2_5_half_billion_model():
|
|||||||
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_qwen3_half_billion_model():
    """Fetch the Qwen/Qwen3-0.6B model snapshot; session-scoped and autouse."""
    model_id = "Qwen/Qwen3-0.6B"
    snapshot_download_w_retry(model_id, repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_tatsu_lab_alpaca_dataset():
|
def download_tatsu_lab_alpaca_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
|
|||||||
118
tests/prompt_strategies/test_chat_templates_thinking.py
Normal file
118
tests/prompt_strategies/test_chat_templates_thinking.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
Tests for splitting reasoning/thinking from content into separate field
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.chat_template import (
|
||||||
|
load,
|
||||||
|
)
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="messages_w_reasoning")
def messages_w_reasoning_fixture():
    """
    One-row dataset whose assistant reply embeds its reasoning in <think> tags
    """
    user_turn = {
        "role": "user",
        "content": "hello",
    }
    # The reasoning trace is inlined ahead of the visible answer.
    assistant_turn = {
        "role": "assistant",
        "content": "<think>lorem</think>\nwelcome",
    }
    conversation = {"messages": [user_turn, assistant_turn]}
    return Dataset.from_list([conversation])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="qwen3_tokenizer")
@enable_hf_offline
def qwen3_tokenizer_fixture(
    download_qwen3_half_billion_model,
):  # pylint: disable=unused-argument
    """
    Tokenizer for Qwen/Qwen3-0.6B; requests the download fixture as a dependency
    """
    return AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitThinking:
    """
    test class to make sure datasets with reasoning content conform to the chat_template strategy
    """

    def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
        """
        <think>...</think> in an assistant turn is moved to reasoning_content
        and the tokenized prompt matches the expected qwen3 token ids
        """
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "train_on_inputs": False,
                "sequence_len": 512,
            }
        )
        ds_cfg = DictDefault(
            {
                "chat_template": "qwen3",
                "message_field_role": "role",
                "message_field_content": "content",
                "message_property_mappings": {
                    "role": "role",
                    "content": "content",
                },
                "roles": {
                    "user": ["user"],
                    "assistant": ["assistant"],
                    "system": ["system"],
                },
                "field_messages": "messages",
                "split_thinking": True,
            }
        )
        strategy = load(qwen3_tokenizer, cfg, ds_cfg)

        # The thread should carry the reasoning in its own field.
        transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
        assert transformed_prompt[0]["role"] == "user"
        assert transformed_prompt[1]["role"] == "assistant"
        assert transformed_prompt[1]["reasoning_content"] == "lorem"
        assert transformed_prompt[1]["content"] == "welcome"

        # The tokenized sample should match the expected token ids.
        res = strategy.tokenize_prompt(messages_w_reasoning[0])
        input_ids = res["input_ids"]
        # fmt: off
        expected_input_ids = [
            151644,     # im_start
            872,        # user
            198,        # \n
            14990,      # hello
            151645,     # im_end
            198,        # \n
            151644,     # im_start
            77091,      # assistant
            198,        # \n
            151667,     # <think>
            198,        # \n
            385, 1826,  # lorem
            198,        # \n
            151668,     # </think>
            271,        # \n\n
            34084,      # welcome
            151645,     # im_end
            198,        # \n
        ]
        # fmt: on
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||||
Reference in New Issue
Block a user