diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 076ddac1f..399bb378a 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -228,6 +228,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): train_on_eos: Optional[str] = None, train_on_eot: Optional[str] = None, eot_tokens: Optional[List[str]] = None, + split_thinking: Optional[bool] = False, ): super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) self.prompter: ChatTemplatePrompter = prompter @@ -247,6 +248,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): self.eot_tokens = ( eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token] ) + self.split_thinking = split_thinking self.images = "images" @@ -655,6 +657,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): transformed_message["role"], transformed_message["role"] ) + # TODO handle reasoning_content with split_thinking + # if the role is assistant that we want to use reasoning_content + if self.split_thinking and transformed_message["role"] == "assistant": + content = transformed_message["content"] + pairs = [("<think>", "</think>"), ("<|begin_of_thought|>", "<|end_of_thought|>")] + for pair in pairs: + if pair[0] in content and pair[1] in content: + start_idx = content.find(pair[0]) + end_idx = content.find(pair[1]) + thinking_content = content[start_idx + len(pair[0]) : end_idx] + transformed_message["reasoning_content"] = thinking_content.strip() + transformed_message["content"] = content[ + end_idx + len(pair[1]) : + ].lstrip() + break + # Determine which keys in the original message were not mapped mapped_values = set(self.prompter.message_property_mappings.values()) remaining_keys = set(message) - mapped_values @@ -689,6 +707,7 @@ class StrategyLoader: "train_on_eos": ds_cfg.get("train_on_eos", "turn"), "train_on_eot": ds_cfg.get("train_on_eot", None), "eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg +
"split_thinking": ds_cfg.get("split_thinking", False), } def __call__( diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index f9b694da1..cc5d6daba 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -50,6 +50,7 @@ class SFTDataset(BaseModel): message_property_mappings: dict[str, str] | None = None message_field_training: str | None = None message_field_training_detail: str | None = None + split_thinking: bool | None = None logprobs_field: str | None = None temperature: float | None = None roles_to_train: list[str] | None = None diff --git a/tests/conftest.py b/tests/conftest.py index 82e4e911b..7fc9a62af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,6 +90,12 @@ def download_qwen_2_5_half_billion_model(): snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model") +@pytest.fixture(scope="session", autouse=True) +def download_qwen3_half_billion_model(): + # download the model + snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model") + + @pytest.fixture(scope="session", autouse=True) def download_tatsu_lab_alpaca_dataset(): # download the dataset diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py new file mode 100644 index 000000000..236a7ec39 --- /dev/null +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -0,0 +1,118 @@ +""" +Tests for splitting reasoning/thinking from content into separate field +""" + +import logging + +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + +from axolotl.prompt_strategies.chat_template import ( + load, +) +from axolotl.utils.dict import DictDefault + +from tests.hf_offline_utils import enable_hf_offline + +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger("axolotl") + + +@pytest.fixture(name="messages_w_reasoning") +def messages_w_reasoning_fixture(): + return Dataset.from_list( + 
[ + { + "messages": [ + { + "role": "user", + "content": "hello", + }, + { + "role": "assistant", + "content": "<think>lorem</think>\nwelcome", + }, + ] + } + ] + ) + + +@pytest.fixture(name="qwen3_tokenizer") +@enable_hf_offline +def qwen3_tokenizer_fixture( + download_qwen3_half_billion_model, +): # pylint: disable=unused-argument + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + return tokenizer + + +class TestSplitThinking: + """ + test class to make sure datasets with reasoning content conforms to the chat_template strategy + """ + + def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer): + # pylint: disable=duplicate-code + strategy = load( + qwen3_tokenizer, + DictDefault( + { + "train_on_inputs": False, + "sequence_len": 512, + } + ), + DictDefault( + { + "chat_template": "qwen3", + "message_field_role": "role", + "message_field_content": "content", + "message_property_mappings": { + "role": "role", + "content": "content", + }, + "roles": { + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + "field_messages": "messages", + "split_thinking": True, + } + ), + ) + transformed_prompt = strategy.get_conversation_thread(messages_w_reasoning[0]) + assert transformed_prompt[0]["role"] == "user" + assert transformed_prompt[1]["role"] == "assistant" + assert transformed_prompt[1]["reasoning_content"] == "lorem" + assert transformed_prompt[1]["content"] == "welcome" + + res = strategy.tokenize_prompt(messages_w_reasoning[0]) + input_ids = res["input_ids"] + # fmt: off + expected_input_ids = [ + 151644, # im_start + 872, # user + 198, # \n + 14990, # hello + 151645, # im_end + 198, # \n + 151644, # im_start + 77091, # assistant + 198, # \n + 151667, # <think> + 198, # \n + 385, 1826, # lorem + 198, # \n + 151668, # </think> + 271, # \n + 34084, # welcome + 151645, # im_end + 198, # \n + ] + # fmt: on + assert ( + input_ids == expected_input_ids + ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"