From 315cdeede976d604d5da17e91d1339c5a20314ec Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 10 Apr 2026 14:11:10 -0400 Subject: [PATCH] handle trainable/masked spans in content and reasoning content (#3592) --- docs/dataset-formats/conversation.qmd | 107 ++++ .../prompt_strategies/chat_template.py | 194 ++++++- .../test_chat_templates_advanced.py | 476 ++++++++++++++++++ 3 files changed, 767 insertions(+), 10 deletions(-) diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 9d018dc49..adbb88645 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -302,6 +302,113 @@ datasets: It is not necessary to set both `message_field_training` and `message_field_training_detail` at once. ::: +#### Content parts with per-part training control + +Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer). + +```{.json filename="data.jsonl"} +{ + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think step by step...", "train": false}, + {"type": "text", "text": " The answer is 4.", "train": true} + ] + } + ] +} +``` + +The configuration is the same as standard `chat_template` — no extra fields needed: + +```yaml +datasets: + - path: ... 
+ type: chat_template + roles_to_train: ["assistant"] +``` + +Each content part supports: + +- `type`: `"text"` (required) +- `text`: the text value (also accepts `content` or `value` as the key) +- `train`: `true`/`false` (optional) — whether to train on this part +- `weight`: `0`/`1` (optional) — alternative to `train`: `0` masks the part, `1` trains it; `train` takes precedence when both are set + +If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`). + +::: {.callout-warning title="Whitespace at part boundaries"} +BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**: + +**Split BEFORE spaces** (space goes with the next part): + +```json +[ + {"type": "text", "text": "Let me think...", "train": false}, + {"type": "text", "text": " The answer is 4.", "train": true} +] +``` + +**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked): + +```json +[ + {"type": "text", "text": "Let me think... ", "train": false}, + {"type": "text", "text": "The answer is 4.", "train": true} +] +``` + +In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`. + +**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part: + +```json +[ + {"type": "text", "text": "Thinking:\n", "train": false}, + {"type": "text", "text": "The answer is 4.", "train": true} +] +``` + +Axolotl will log a warning if it detects a trailing space or tab at a boundary between parts with different training flags. 
+::: + +::: {.callout-note} +When all content parts in a message are text parts, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization. +::: + +##### Per-part training on reasoning_content + +For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections: + +```{.json filename="data.jsonl"} +{ + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]}, + { + "role": "assistant", + "reasoning_content": [ + {"type": "text", "text": "Hmm maybe 2+2=5.", "train": false}, + {"type": "text", "text": " Wait no, 2+2=4.", "train": true} + ], + "content": [ + {"type": "text", "text": "The answer is 4.", "train": true} + ] + } + ] +} +``` + +The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires. + +::: {.callout-tip} +When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data. +::: + +The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part. + + #### Reasoning split (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template. 
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 57d3bfdf2..a7f749f3b 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -471,6 +471,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): content = turn.get("content") train_turn = turn.get("training") train_detail = turn.get("training_detail") + reasoning_train_detail = turn.get("reasoning_training_detail") LOG.debug( f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}" @@ -479,8 +480,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): should_train = None if train_turn is not None: should_train = train_turn - elif train_detail is not None: - should_train = bool(train_detail) + elif train_detail is not None or reasoning_train_detail is not None: + should_train = bool(train_detail) or bool(reasoning_train_detail) else: should_train = self.train_on_inputs or role in self.roles_to_train @@ -500,15 +501,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): continue + thinking_key = self.prompter.template_thinking_key + has_reasoning = thinking_key and turn.get(thinking_key) is not None + has_any_detail = train_detail or reasoning_train_detail + + # When train_detail is present and the turn has reasoning_content, + # use content_only=True so find_turn returns content-only boundaries + # (excluding reasoning_content + template separator tokens). 
+ use_content_only = bool(has_any_detail and has_reasoning) + turn_start_idx, turn_end_idx = self.find_turn( - turns=turns, turn_idx=index, tools=tools + turns=turns, + turn_idx=index, + tools=tools, + content_only=use_content_only, ) LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}") if should_train and turn_start_idx != -1 and turn_end_idx != -1: if train_detail: - # Block multi-content for now if not isinstance(content, str): raise ValueError( "`train_detail` is not supported when `content` is not a string." @@ -526,7 +538,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): LOG.debug( f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}" ) - else: + elif not reasoning_train_detail: + # No per-part detail on either field — train the whole span labels[turn_start_idx:turn_end_idx] = input_ids[ turn_start_idx:turn_end_idx ] @@ -534,6 +547,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): f"Set labels for training from {turn_start_idx} to {turn_end_idx}" ) + # Handle reasoning_content training_detail separately + if should_train and reasoning_train_detail and has_reasoning: + reasoning_text = turn[thinking_key] + if not isinstance(reasoning_text, str): + raise ValueError( + "`reasoning_training_detail` is not supported when reasoning_content is not a string." 
+ ) + + reasoning_start, reasoning_end = self.find_turn( + turns=turns, + turn_idx=index, + tools=tools, + reasoning_only=True, + ) + + if reasoning_start != -1 and reasoning_end != -1: + token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore + reasoning_text, reasoning_train_detail + ) + LOG.debug(f"Reasoning token offsets: {token_offsets}") + for i, offset in enumerate(token_offsets): + if offset != IGNORE_TOKEN_ID and reasoning_start + i < len( + input_ids + ): + labels[reasoning_start + i] = input_ids[reasoning_start + i] + LOG.debug(f"Labels after processing turn {index}: {labels}") # Handle special tokens (EOT and EOS) @@ -611,10 +650,24 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return -1 def find_turn( - self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None + self, + turns: list[dict], + turn_idx: int, + tools: list[dict] | None = None, + content_only: bool = False, + reasoning_only: bool = False, ): """ Locate the starting and ending indices of the specified turn in a conversation. + + Args: + content_only: If True and the turn has reasoning_content (template_thinking_key), + preserve reasoning_content in the dummy turn so the diff only captures the + content field boundaries. This is needed for correct training_detail alignment + when reasoning_content is present. + reasoning_only: If True, preserve content in the dummy turn and replace + reasoning_content with a dummy, so the diff only captures the + reasoning_content field boundaries. 
""" if turn_idx >= len(turns): @@ -628,10 +681,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): ): return -1, -1 - empty_turn = { - "role": turns[turn_idx].get("role"), - "content": "[[dummy_message]]", - } + thinking_key = self.prompter.template_thinking_key + + if reasoning_only: + # Keep content as-is, replace reasoning with dummy + empty_turn = { + "role": turns[turn_idx].get("role"), + "content": turns[turn_idx].get("content", ""), + } + if thinking_key and thinking_key in turns[turn_idx]: + empty_turn[thinking_key] = "[[dummy_reasoning]]" + else: + empty_turn = { + "role": turns[turn_idx].get("role"), + "content": "[[dummy_message]]", + } + + # When content_only is True, copy reasoning_content to the dummy turn so + # the diff only captures the content field (not reasoning + separator). + if content_only and thinking_key and thinking_key in turns[turn_idx]: + empty_turn[thinking_key] = turns[turn_idx][thinking_key] # Create conversation versions turns_with_empty = turns[:turn_idx] + [empty_turn] @@ -697,6 +766,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return start_idx, end_idx + @staticmethod + def _convert_content_parts( + content, + ) -> tuple[str, list[dict] | None] | None: + """Convert list content to concatenated string + optional training_detail. + + When content is a list of dicts (content parts), each part can specify: + - ``text``, ``content``, or ``value``: the text string + - ``train`` (bool) or ``weight`` (0/1): per-part training flag + + Returns ``(concatenated_text, training_details_or_None)`` if content was + a list, or ``None`` if content was not a list (no conversion needed). + + .. note:: + **Whitespace at part boundaries matters.** BPE tokenizers prepend + spaces to word tokens (e.g. ``" answer"`` is one token). Always + split BEFORE spaces:: + + GOOD: ["Let me think...", " The answer is 4."] + BAD: ["Let me think... ", "The answer is 4."] + + Tokens that straddle a boundary are conservatively masked. 
+ Newlines typically merge with preceding punctuation (``":\\n"`` is + one token), so keep newlines with the preceding part. + """ + if not isinstance(content, list): + return None + + text_parts: list[str] = [] + training_details: list[dict] = [] + has_explicit_training = False + offset = 0 + + for part in content: + if isinstance(part, dict): + # Extract text (HF uses "text", also support "content"/"value") + text = ( + part.get("text") or part.get("content") or part.get("value") or "" + ) + text_parts.append(text) + + # Check for per-part training flags + part_train = part.get("train") + part_weight = part.get("weight") + if part_train is not None or part_weight is not None: + has_explicit_training = True + train = ( + part_train + if part_train is not None + else (part_weight not in (0, 0.0)) + ) + else: + train = True # default trainable, gated by turn-level should_train + + if text: + training_details.append( + { + "begin_offset": offset, + "end_offset": offset + len(text) - 1, + "train": train, + } + ) + offset += len(text) + + # Warn about trailing whitespace at boundaries between parts with + # different training flags — this almost always causes token straddling + if has_explicit_training and len(training_details) > 1: + for i in range(len(training_details) - 1): + cur = training_details[i] + nxt = training_details[i + 1] + if cur["train"] != nxt["train"]: + boundary_text = text_parts[i] + if boundary_text and boundary_text[-1] in (" ", "\t"): + LOG.warning( + "Content part %d ends with whitespace at a train/mask boundary. " + "BPE tokenizers typically prepend spaces to word tokens, so " + "the space will merge with the next part's first word and the " + "resulting token will be MASKED (not trained). Move the " + "whitespace to the start of the next content part instead. 
" + "Part text: %r", + i, + boundary_text[-20:], + ) + + concatenated = "".join(text_parts) + details = training_details if has_explicit_training else None + return concatenated, details + def get_conversation_thread(self, prompt): turns = [] @@ -723,6 +880,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): if training_detail is not None: turn["training_detail"] = training_detail + # Convert list content/reasoning_content to string + auto-generated + # training_detail. See _convert_content_parts for whitespace guidance. + content_result = self._convert_content_parts(turn.get("content")) + if content_result is not None: + turn["content"] = content_result[0] + if content_result[1] is not None: + turn["training_detail"] = content_result[1] + + # Also convert reasoning_content (template_thinking_key) if it's a list + thinking_key = self.prompter.template_thinking_key + if thinking_key and thinking_key in turn: + reasoning_result = self._convert_content_parts(turn[thinking_key]) + if reasoning_result is not None: + turn[thinking_key] = reasoning_result[0] + if reasoning_result[1] is not None: + turn["reasoning_training_detail"] = reasoning_result[1] + turns.append(turn) if self.prompter.drop_system_message and turns[0]["role"] == "system": diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 7d4e6883f..701e4d01e 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -916,6 +916,235 @@ class TestChatTemplateConfigurations: LOG.debug(f"Final labels: {labels}") LOG.debug(f"Final input_ids: {input_ids}") + @enable_hf_offline + def test_content_parts_training( + self, + tokenizer, + chat_template, + chat_template_jinja, + eos_token, + request, + ): + LOG.info("Testing with content as list of parts with per-part training") + + tokenizer, chat_template_jinja = self.setup_tokenizer( + tokenizer, chat_template, 
chat_template_jinja, eos_token, request + ) + + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + tokenizer, + chat_template=get_chat_template( + chat_template, jinja_template=chat_template_jinja + ), + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + + # Dataset where assistant content is a list of parts with per-part training + conversation = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are an AI assistant."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is 2+2?"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think...", "train": False}, + {"type": "text", "text": "The answer is 4.", "train": True}, + ], + }, + ] + + dataset = Dataset.from_dict({"messages": [conversation]}) + res = strategy.tokenize_prompt(dataset[0]) + turns = strategy.get_conversation_thread(dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Find the assistant turn (last turn) + assistant_turn_idx = len(turns) - 1 + start_idx, end_idx = strategy.find_turn( + turns=turns, turn_idx=assistant_turn_idx + ) + + assert start_idx != -1 and end_idx != -1, ( + "Could not find assistant turn boundaries" + ) + + decoded = tokenizer.decode(input_ids[start_idx:end_idx]) + LOG.debug(f"Assistant turn decoded: {decoded}") + + # Tokenize each part separately to find their boundaries + part1_text = "Let me think..." + part2_text = "The answer is 4." 
+ + # Verify the concatenated content is in the decoded output + assert part1_text in decoded, ( + f"Part 1 '{part1_text}' not found in decoded: {decoded}" + ) + assert part2_text in decoded, ( + f"Part 2 '{part2_text}' not found in decoded: {decoded}" + ) + + # Verify that part1 tokens (train=False) are masked + # and part2 tokens (train=True) are labeled + turn_labels = labels[start_idx:end_idx] + + # Find where part2 starts in the token sequence + part1_tokens = tokenizer(part1_text, add_special_tokens=False)["input_ids"] + part2_tokens = tokenizer(part2_text, add_special_tokens=False)["input_ids"] + + # The first part should be masked (all IGNORE_TOKEN_ID) + # Due to token boundary alignment, check that at least the interior tokens + # of part1 are masked + assert any(label == IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected some masked labels for train=False part, but got {turn_labels}" + ) + + # The second part should be trained (not IGNORE_TOKEN_ID) + assert any(label != IGNORE_TOKEN_ID for label in turn_labels), ( + f"Expected some trained labels for train=True part, but got {turn_labels}" + ) + + # More precise check: first N tokens should be masked, last M tokens should be trained + # where N ~ len(part1_tokens) and M ~ len(part2_tokens) + # Allow for token boundary effects at the boundary + num_masked = sum(1 for label in turn_labels if label == IGNORE_TOKEN_ID) + num_trained = sum(1 for label in turn_labels if label != IGNORE_TOKEN_ID) + + LOG.debug(f"Turn labels: {turn_labels}") + LOG.debug(f"Masked tokens: {num_masked}, Trained tokens: {num_trained}") + LOG.debug( + f"Part1 tokens: {len(part1_tokens)}, Part2 tokens: {len(part2_tokens)}" + ) + + # The number of masked tokens should be roughly the size of part1 + # and the number of trained tokens should be roughly the size of part2 + assert num_masked > 0, "Expected masked tokens for the train=False part" + assert num_trained > 0, "Expected trained tokens for the train=True part" + + 
@enable_hf_offline + def test_content_parts_with_weight( + self, + tokenizer, + chat_template, + chat_template_jinja, + eos_token, + request, + ): + LOG.info("Testing with content parts using weight field") + + tokenizer, chat_template_jinja = self.setup_tokenizer( + tokenizer, chat_template, chat_template_jinja, eos_token, request + ) + + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + tokenizer, + chat_template=get_chat_template( + chat_template, jinja_template=chat_template_jinja + ), + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + + # Dataset using weight instead of train + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Thinking step by step: ", "weight": 0}, + {"type": "text", "text": "Hello! How can I help?", "weight": 1}, + ], + }, + ] + + dataset = Dataset.from_dict({"messages": [conversation]}) + res = strategy.tokenize_prompt(dataset[0]) + labels = res["labels"] + + # There should be both masked and trained labels + has_masked = any(label == IGNORE_TOKEN_ID for label in labels) + has_trained = any(label != IGNORE_TOKEN_ID for label in labels) + assert has_masked, "Expected masked tokens (weight=0 part + user turn)" + assert has_trained, "Expected trained tokens (weight=1 part)" + + @enable_hf_offline + def test_content_parts_string_passthrough( + self, + tokenizer, + chat_template, + chat_template_jinja, + eos_token, + request, + ): + LOG.info("Testing that string content still works alongside list content") + + tokenizer, chat_template_jinja = self.setup_tokenizer( + tokenizer, chat_template, chat_template_jinja, eos_token, request + ) + + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + tokenizer, + chat_template=get_chat_template( + chat_template, jinja_template=chat_template_jinja 
+ ), + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + + # All list content in the conversation + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is 2+2?"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The answer is 4.", "train": True}, + ], + }, + ] + + dataset = Dataset.from_dict({"messages": [conversation]}) + res = strategy.tokenize_prompt(dataset[0]) + + # Should tokenize without errors + assert "input_ids" in res + assert "labels" in res + assert len(res["input_ids"]) > 0 + def test_get_chat_template_variables( self, tokenizer, chat_template, chat_template_jinja, eos_token, request ): @@ -1428,3 +1657,250 @@ class TestChatTemplateToolCalling: assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( f"Assistant turn {i} should be unmasked" ) + + +class TestChatTemplateReasoningContent: + """ + Test class for reasoning_content with content parts. + """ + + @enable_hf_offline + def test_reasoning_content_with_content_parts(self, qwen3_tokenizer): + """Test that reasoning_content as string + content as list parts works correctly. 
+ Content training_detail offsets should align with content-only boundaries.""" + LOG.info("Testing reasoning_content with content parts on qwen3") + + tokenizer = deepcopy(qwen3_tokenizer) + + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + tokenizer, + chat_template=get_chat_template("qwen3"), + message_property_mappings={ + "role": "role", + "content": "content", + "reasoning_content": "reasoning_content", + }, + ), + tokenizer=tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + + # reasoning_content is a plain string, content is list with per-part training + conversation = [ + { + "role": "user", + "content": [{"type": "text", "text": "What is 2+2?"}], + }, + { + "role": "assistant", + "reasoning_content": "Step 1: 2+2=4", + "content": [ + {"type": "text", "text": "The answer is 4.", "train": True}, + ], + }, + ] + + dataset = Dataset.from_dict({"messages": [conversation]}) + res = strategy.tokenize_prompt(dataset[0]) + turns = strategy.get_conversation_thread(dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Find the assistant turn + assistant_idx = 1 + start_idx, end_idx = strategy.find_turn( + turns=turns, turn_idx=assistant_idx, content_only=True + ) + + assert start_idx != -1 and end_idx != -1, ( + "Could not find assistant content boundaries" + ) + + # The content-only span should contain "The answer is 4." but NOT "Step 1: 2+2=4" + decoded_span = tokenizer.decode(input_ids[start_idx:end_idx]) + assert "The answer is 4." 
in decoded_span, ( + f"Content not found in span: {decoded_span}" + ) + assert "Step 1" not in decoded_span, ( + f"Reasoning should not be in content-only span: {decoded_span}" + ) + + # Verify that content tokens are trained + content_labels = labels[start_idx:end_idx] + assert any(label != IGNORE_TOKEN_ID for label in content_labels), ( + f"Expected trained labels in content span, got {content_labels}" + ) + + @enable_hf_offline + def test_reasoning_content_per_part_masking(self, qwen3_tokenizer): + """Test masking incorrect reasoning while training on self-correction. + This is the core use case: mask out wrong thoughts, train on corrections.""" + LOG.info("Testing reasoning_content per-part masking on qwen3") + + tokenizer = deepcopy(qwen3_tokenizer) + + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + tokenizer, + chat_template=get_chat_template("qwen3"), + message_property_mappings={ + "role": "role", + "content": "content", + "reasoning_content": "reasoning_content", + }, + ), + tokenizer=tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + + # Reasoning has wrong step (masked) then self-correction (trained) + conversation = [ + { + "role": "user", + "content": [{"type": "text", "text": "What is 2+2?"}], + }, + { + "role": "assistant", + "reasoning_content": [ + {"type": "text", "text": "Hmm maybe 2+2=5.", "train": False}, + {"type": "text", "text": " Wait no, 2+2=4.", "train": True}, + ], + "content": [ + {"type": "text", "text": "The answer is 4.", "train": True}, + ], + }, + ] + + dataset = Dataset.from_dict({"messages": [conversation]}) + res = strategy.tokenize_prompt(dataset[0]) + turns = strategy.get_conversation_thread(dataset[0]) + labels = res["labels"] + input_ids = res["input_ids"] + + # Find reasoning boundaries + reasoning_start, reasoning_end = strategy.find_turn( + turns=turns, turn_idx=1, reasoning_only=True + ) + assert reasoning_start != -1 and reasoning_end != -1, ( + "Could not find 
reasoning boundaries" + ) + + decoded_reasoning = tokenizer.decode(input_ids[reasoning_start:reasoning_end]) + LOG.debug(f"Reasoning span: {decoded_reasoning!r}") + assert "2+2=5" in decoded_reasoning, ( + f"Wrong step not in reasoning span: {decoded_reasoning}" + ) + assert "2+2=4" in decoded_reasoning, ( + f"Correction not in reasoning span: {decoded_reasoning}" + ) + + # Verify reasoning labels have both masked and trained tokens + reasoning_labels = labels[reasoning_start:reasoning_end] + reasoning_ids = input_ids[reasoning_start:reasoning_end] + + # Decode only the trained tokens — should be exactly the self-correction + trained_ids = [ + tid + for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True) + if lab != IGNORE_TOKEN_ID + ] + trained_text = tokenizer.decode(trained_ids) + assert trained_text.strip() == "Wait no, 2+2=4.", ( + f"Expected trained reasoning to be 'Wait no, 2+2=4.', got: {trained_text!r}" + ) + + # Decode only the masked tokens — should be exactly the incorrect step + masked_ids = [ + tid + for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True) + if lab == IGNORE_TOKEN_ID + ] + masked_text = tokenizer.decode(masked_ids) + assert masked_text.strip() == "Hmm maybe 2+2=5.", ( + f"Expected masked reasoning to be 'Hmm maybe 2+2=5.', got: {masked_text!r}" + ) + + # Find content boundaries + content_start, content_end = strategy.find_turn( + turns=turns, turn_idx=1, content_only=True + ) + assert content_start != -1 and content_end != -1, ( + "Could not find content boundaries" + ) + + # Content should be fully trained — decode trained tokens to verify + content_labels = labels[content_start:content_end] + content_ids = input_ids[content_start:content_end] + content_trained_ids = [ + tid + for tid, lab in zip(content_ids, content_labels, strict=True) + if lab != IGNORE_TOKEN_ID + ] + content_trained_text = tokenizer.decode(content_trained_ids) + assert "The answer is 4." 
in content_trained_text, ( + f"Expected 'The answer is 4.' in trained content tokens, " + f"got: {content_trained_text!r}" + ) + assert all(label != IGNORE_TOKEN_ID for label in content_labels), ( + f"Expected all content labels trained, got {content_labels}" + ) + + @enable_hf_offline + def test_reasoning_content_as_list_no_training_flags(self, qwen3_tokenizer): + """Test that reasoning_content as list without training flags still works.""" + LOG.info("Testing reasoning_content as list without training flags on qwen3") + + tokenizer = deepcopy(qwen3_tokenizer) + + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + tokenizer, + chat_template=get_chat_template("qwen3"), + message_property_mappings={ + "role": "role", + "content": "content", + "reasoning_content": "reasoning_content", + }, + ), + tokenizer=tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + ) + + # Both as lists, no per-part training flags + conversation = [ + { + "role": "user", + "content": [{"type": "text", "text": "What is 2+2?"}], + }, + { + "role": "assistant", + "reasoning_content": [ + {"type": "text", "text": "Step 1: addition."}, + {"type": "text", "text": " Step 2: 2+2=4."}, + ], + "content": [ + {"type": "text", "text": "The answer is 4."}, + ], + }, + ] + + dataset = Dataset.from_dict({"messages": [conversation]}) + res = strategy.tokenize_prompt(dataset[0]) + + # Should tokenize without errors + assert "input_ids" in res + assert "labels" in res + assert len(res["input_ids"]) > 0 + + # Verify the full output contains both reasoning and content + full_text = tokenizer.decode(res["input_ids"]) + assert "Step 1: addition." in full_text + assert "Step 2: 2+2=4." in full_text + assert "The answer is 4." in full_text