diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 076ddac1f..399bb378a 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -228,6 +228,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_eos: Optional[str] = None,
train_on_eot: Optional[str] = None,
eot_tokens: Optional[List[str]] = None,
+ split_thinking: Optional[bool] = False,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.prompter: ChatTemplatePrompter = prompter
@@ -247,6 +248,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self.eot_tokens = (
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
)
+ self.split_thinking = split_thinking
self.images = "images"
@@ -655,6 +657,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
transformed_message["role"], transformed_message["role"]
)
+ # TODO handle reasoning_content with split_thinking
+ # if the role is assistant that we want to use reasoning_content
+ if self.split_thinking and transformed_message["role"] == "assistant":
+ content = transformed_message["content"]
+            pairs = [("<think>", "</think>"), ("<|begin_of_thought|>", "<|end_of_thought|>")]
+ for pair in pairs:
+ if pair[0] in content and pair[1] in content:
+ start_idx = content.find(pair[0])
+ end_idx = content.find(pair[1])
+ thinking_content = content[start_idx + len(pair[0]) : end_idx]
+ transformed_message["reasoning_content"] = thinking_content.strip()
+ transformed_message["content"] = content[
+ end_idx + len(pair[1]) :
+ ].lstrip()
+ break
+
# Determine which keys in the original message were not mapped
mapped_values = set(self.prompter.message_property_mappings.values())
remaining_keys = set(message) - mapped_values
@@ -689,6 +707,7 @@ class StrategyLoader:
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"train_on_eot": ds_cfg.get("train_on_eot", None),
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
+ "split_thinking": ds_cfg.get("split_thinking", False),
}
def __call__(
diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py
index f9b694da1..cc5d6daba 100644
--- a/src/axolotl/utils/schemas/datasets.py
+++ b/src/axolotl/utils/schemas/datasets.py
@@ -50,6 +50,7 @@ class SFTDataset(BaseModel):
message_property_mappings: dict[str, str] | None = None
message_field_training: str | None = None
message_field_training_detail: str | None = None
+ split_thinking: bool | None = None
logprobs_field: str | None = None
temperature: float | None = None
roles_to_train: list[str] | None = None
diff --git a/tests/conftest.py b/tests/conftest.py
index 82e4e911b..7fc9a62af 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -90,6 +90,12 @@ def download_qwen_2_5_half_billion_model():
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
+@pytest.fixture(scope="session", autouse=True)
+def download_qwen3_half_billion_model():
+ # download the model
+ snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
+
+
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py
new file mode 100644
index 000000000..236a7ec39
--- /dev/null
+++ b/tests/prompt_strategies/test_chat_templates_thinking.py
@@ -0,0 +1,118 @@
+"""
+Tests for splitting reasoning/thinking from content into separate field
+"""
+
+import logging
+
+import pytest
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+from axolotl.prompt_strategies.chat_template import (
+ load,
+)
+from axolotl.utils.dict import DictDefault
+
+from tests.hf_offline_utils import enable_hf_offline
+
+logging.basicConfig(level=logging.DEBUG)
+LOG = logging.getLogger("axolotl")
+
+
+@pytest.fixture(name="messages_w_reasoning")
+def messages_w_reasoning_fixture():
+ return Dataset.from_list(
+ [
+ {
+ "messages": [
+ {
+ "role": "user",
+ "content": "hello",
+ },
+ {
+ "role": "assistant",
+                    "content": "<think>lorem</think>\nwelcome",
+ },
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="qwen3_tokenizer")
+@enable_hf_offline
+def qwen3_tokenizer_fixture(
+ download_qwen3_half_billion_model,
+): # pylint: disable=unused-argument
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
+
+ return tokenizer
+
+
+class TestSplitThinking:
+ """
+ test class to make sure datasets with reasoning content conforms to the chat_template strategy
+ """
+
+ def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
+ # pylint: disable=duplicate-code
+ strategy = load(
+ qwen3_tokenizer,
+ DictDefault(
+ {
+ "train_on_inputs": False,
+ "sequence_len": 512,
+ }
+ ),
+ DictDefault(
+ {
+ "chat_template": "qwen3",
+ "message_field_role": "role",
+ "message_field_content": "content",
+ "message_property_mappings": {
+ "role": "role",
+ "content": "content",
+ },
+ "roles": {
+ "user": ["user"],
+ "assistant": ["assistant"],
+ "system": ["system"],
+ },
+ "field_messages": "messages",
+ "split_thinking": True,
+ }
+ ),
+ )
+ transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0])
+ assert transformed_prompt[0]["role"] == "user"
+ assert transformed_prompt[1]["role"] == "assistant"
+ assert transformed_prompt[1]["reasoning_content"] == "lorem"
+ assert transformed_prompt[1]["content"] == "welcome"
+
+ res = strategy.tokenize_prompt(messages_w_reasoning[0])
+ input_ids = res["input_ids"]
+ # fmt: off
+ expected_input_ids = [
+ 151644, # im_start
+ 872, # user
+ 198, # \n
+ 14990, # hello
+ 151645, # im_end
+ 198, # \n
+ 151644, # im_start
+ 77091, # assistant
+ 198, # \n
+ 151667, # think
+ 198, # \n
+ 385, 1826, # lorem
+ 198, # \n
+ 151668, # /think
+ 271, # \n
+ 34084, # welcome
+ 151645, # im_end
+ 198, # \n
+ ]
+ # fmt: on
+ assert (
+ input_ids == expected_input_ids
+ ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"