handle trainable/masked spans in content and reasoning content (#3592)
@@ -302,6 +302,113 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::

#### Content parts with per-part training control

Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).

```{.json filename="data.jsonl"}
{
  "messages": [
    {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
    {
      "role": "assistant",
      "content": [
        {"type": "text", "text": "Let me think step by step...", "train": false},
        {"type": "text", "text": " The answer is 4.", "train": true}
      ]
    }
  ]
}
```

The configuration is the same as standard `chat_template` — no extra fields needed:

```yaml
datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```

Each content part supports:

- `type`: `"text"` (required)
- `text`: the text value (also accepts `content` or `value` as the key)
- `train`: `true`/`false` (optional) — whether to train on this part
- `weight`: `0`/`1` (optional) — alternative to `train`

If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
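Under the hood, a parts list like the one above is flattened into a single string plus character-offset training details. The following is a simplified sketch of that conversion, based only on the rules described above (`convert_parts` is a hypothetical name, not Axolotl's API):

```python
def convert_parts(parts):
    """Flatten content parts into (full_text, per-part char-offset details).

    Returns None for the details when no part carries an explicit
    train/weight flag, so the turn-level decision applies unchanged.
    """
    text_chunks, details = [], []
    has_flags = False
    offset = 0
    for part in parts:
        text = part.get("text", "")
        train, weight = part.get("train"), part.get("weight")
        if train is not None or weight is not None:
            has_flags = True
        if train is not None:
            flag = bool(train)
        elif weight is not None:
            flag = weight not in (0, 0.0)
        else:
            flag = True  # inherit: trained iff the turn itself is trained
        if text:
            details.append(
                {
                    "begin_offset": offset,
                    "end_offset": offset + len(text) - 1,  # inclusive end
                    "train": flag,
                }
            )
            offset += len(text)
        text_chunks.append(text)
    return "".join(text_chunks), (details if has_flags else None)


full, details = convert_parts(
    [
        {"type": "text", "text": "Let me think...", "train": False},
        {"type": "text", "text": " The answer is 4.", "train": True},
    ]
)
```

Here `full` is `"Let me think... The answer is 4."` and `details` marks characters 0-14 as masked and 15-31 as trainable.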

::: {.callout-warning title="Whitespace at part boundaries"}
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:

**Split BEFORE spaces** (space goes with the next part):

```json
[
  {"type": "text", "text": "Let me think...", "train": false},
  {"type": "text", "text": " The answer is 4.", "train": true}
]
```

**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):

```json
[
  {"type": "text", "text": "Let me think... ", "train": false},
  {"type": "text", "text": "The answer is 4.", "train": true}
]
```

In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.

**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:

```json
[
  {"type": "text", "text": "Thinking:\n", "train": false},
  {"type": "text", "text": "The answer is 4.", "train": true}
]
```

Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
:::
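To make the straddling rule concrete, here is a small illustrative sketch (the function name and the token spans are hypothetical, chosen to mirror the behavior described above): a token is trained only when its character span lies entirely inside a trainable range; anything crossing a boundary is masked.

```python
IGNORE_TOKEN_ID = -100  # Hugging Face convention for masked labels


def label_tokens(token_spans, token_ids, trainable_ranges):
    """token_spans: per-token (start, end) char offsets, end-exclusive.
    trainable_ranges: (start, end) char ranges covered by train=true parts.
    A token keeps its id only if it falls fully inside a trainable range."""
    labels = []
    for (start, end), tok_id in zip(token_spans, token_ids):
        fully_inside = any(start >= r0 and end <= r1 for r0, r1 in trainable_ranges)
        labels.append(tok_id if fully_inside else IGNORE_TOKEN_ID)
    return labels


# Bad split: "Let me think... " (masked, chars 0-16) + "The answer is 4."
# (trainable, chars 16-32). A hypothetical token " The" spans chars 15..19,
# crossing the boundary at 16, so it is masked despite part 2 being trainable.
spans = [(0, 15), (15, 19), (19, 32)]
ids = [11, 22, 33]
bad = label_tokens(spans, ids, [(16, 32)])   # middle token masked
good = label_tokens(spans, ids, [(15, 32)])  # good split: range starts at 15
```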

::: {.callout-note}
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
:::

##### Per-part training on reasoning_content

For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:

```{.json filename="data.jsonl"}
{
  "messages": [
    {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
    {
      "role": "assistant",
      "reasoning_content": [
        {"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
        {"type": "text", "text": " Wait no, 2+2=4.", "train": true}
      ],
      "content": [
        {"type": "text", "text": "The answer is 4.", "train": true}
      ]
    }
  ]
}
```

The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.

::: {.callout-tip}
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
:::

The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
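If you want to lint your own data for this issue before training, a minimal check along these lines can help (a sketch, not part of Axolotl's API; `whitespace_boundary_issues` is a hypothetical helper mirroring the warning Axolotl logs):

```python
def whitespace_boundary_issues(parts):
    """Return indices of parts whose trailing space/tab sits at a boundary
    where the train flag changes — the pattern that causes token straddling."""
    flags = []
    for p in parts:
        if "train" in p:
            flags.append(bool(p["train"]))
        elif "weight" in p:
            flags.append(p["weight"] not in (0, 0.0))
        else:
            flags.append(True)  # no flag: inherits the turn-level decision

    issues = []
    for i in range(len(parts) - 1):
        text = parts[i].get("text", "")
        if flags[i] != flags[i + 1] and text and text[-1] in (" ", "\t"):
            issues.append(i)
    return issues


bad = [
    {"type": "text", "text": "Hmm maybe 2+2=5. ", "train": False},
    {"type": "text", "text": "Wait no, 2+2=4.", "train": True},
]
print(whitespace_boundary_issues(bad))  # part 0 ends with a risky space
```

Run it over each message's parts list; any returned index points at a part whose trailing whitespace should be moved to the start of the next part.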

#### Reasoning split

(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.

@@ -471,6 +471,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            content = turn.get("content")
            train_turn = turn.get("training")
            train_detail = turn.get("training_detail")
+           reasoning_train_detail = turn.get("reasoning_training_detail")

            LOG.debug(
                f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
@@ -479,8 +480,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            should_train = None
            if train_turn is not None:
                should_train = train_turn
-           elif train_detail is not None:
-               should_train = bool(train_detail)
+           elif train_detail is not None or reasoning_train_detail is not None:
+               should_train = bool(train_detail) or bool(reasoning_train_detail)
            else:
                should_train = self.train_on_inputs or role in self.roles_to_train

@@ -500,15 +501,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

                continue

+           thinking_key = self.prompter.template_thinking_key
+           has_reasoning = thinking_key and turn.get(thinking_key) is not None
+           has_any_detail = train_detail or reasoning_train_detail
+
+           # When train_detail is present and the turn has reasoning_content,
+           # use content_only=True so find_turn returns content-only boundaries
+           # (excluding reasoning_content + template separator tokens).
+           use_content_only = bool(has_any_detail and has_reasoning)
+
            turn_start_idx, turn_end_idx = self.find_turn(
-               turns=turns, turn_idx=index, tools=tools
+               turns=turns,
+               turn_idx=index,
+               tools=tools,
+               content_only=use_content_only,
            )

            LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

            if should_train and turn_start_idx != -1 and turn_end_idx != -1:
                if train_detail:
+                   # Block multi-content for now
+                   if not isinstance(content, str):
+                       raise ValueError(
+                           "`train_detail` is not supported when `content` is not a string."
@@ -526,7 +538,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                            LOG.debug(
                                f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
                            )
-               else:
+               elif not reasoning_train_detail:
+                   # No per-part detail on either field — train the whole span
                    labels[turn_start_idx:turn_end_idx] = input_ids[
                        turn_start_idx:turn_end_idx
                    ]
@@ -534,6 +547,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                        f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
                    )

+           # Handle reasoning_content training_detail separately
+           if should_train and reasoning_train_detail and has_reasoning:
+               reasoning_text = turn[thinking_key]
+               if not isinstance(reasoning_text, str):
+                   raise ValueError(
+                       "`reasoning_training_detail` is not supported when reasoning_content is not a string."
+                   )
+
+               reasoning_start, reasoning_end = self.find_turn(
+                   turns=turns,
+                   turn_idx=index,
+                   tools=tools,
+                   reasoning_only=True,
+               )
+
+               if reasoning_start != -1 and reasoning_end != -1:
+                   token_offsets = self.prompter.get_offsets_for_train_detail(  # type: ignore
+                       reasoning_text, reasoning_train_detail
+                   )
+                   LOG.debug(f"Reasoning token offsets: {token_offsets}")
+                   for i, offset in enumerate(token_offsets):
+                       if offset != IGNORE_TOKEN_ID and reasoning_start + i < len(
+                           input_ids
+                       ):
+                           labels[reasoning_start + i] = input_ids[reasoning_start + i]
+
            LOG.debug(f"Labels after processing turn {index}: {labels}")

        # Handle special tokens (EOT and EOS)
@@ -611,10 +650,24 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        return -1

    def find_turn(
-       self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
+       self,
+       turns: list[dict],
+       turn_idx: int,
+       tools: list[dict] | None = None,
+       content_only: bool = False,
+       reasoning_only: bool = False,
    ):
        """
        Locate the starting and ending indices of the specified turn in a conversation.

+       Args:
+           content_only: If True and the turn has reasoning_content (template_thinking_key),
+               preserve reasoning_content in the dummy turn so the diff only captures the
+               content field boundaries. This is needed for correct training_detail alignment
+               when reasoning_content is present.
+           reasoning_only: If True, preserve content in the dummy turn and replace
+               reasoning_content with a dummy, so the diff only captures the
+               reasoning_content field boundaries.
        """

        if turn_idx >= len(turns):
@@ -628,11 +681,27 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        ):
            return -1, -1

+       thinking_key = self.prompter.template_thinking_key
+
+       if reasoning_only:
+           # Keep content as-is, replace reasoning with dummy
+           empty_turn = {
+               "role": turns[turn_idx].get("role"),
+               "content": turns[turn_idx].get("content", ""),
+           }
+           if thinking_key and thinking_key in turns[turn_idx]:
+               empty_turn[thinking_key] = "[[dummy_reasoning]]"
+       else:
            empty_turn = {
                "role": turns[turn_idx].get("role"),
                "content": "[[dummy_message]]",
            }

+       # When content_only is True, copy reasoning_content to the dummy turn so
+       # the diff only captures the content field (not reasoning + separator).
+       if content_only and thinking_key and thinking_key in turns[turn_idx]:
+           empty_turn[thinking_key] = turns[turn_idx][thinking_key]
+
        # Create conversation versions
        turns_with_empty = turns[:turn_idx] + [empty_turn]
        turns_with_content = turns[: turn_idx + 1]
@@ -697,6 +766,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

        return start_idx, end_idx

+   @staticmethod
+   def _convert_content_parts(
+       content,
+   ) -> tuple[str, list[dict] | None] | None:
+       """Convert list content to concatenated string + optional training_detail.
+
+       When content is a list of dicts (content parts), each part can specify:
+       - ``text``, ``content``, or ``value``: the text string
+       - ``train`` (bool) or ``weight`` (0/1): per-part training flag
+
+       Returns ``(concatenated_text, training_details_or_None)`` if content was
+       a list, or ``None`` if content was not a list (no conversion needed).
+
+       .. note::
+           **Whitespace at part boundaries matters.** BPE tokenizers prepend
+           spaces to word tokens (e.g. ``" answer"`` is one token). Always
+           split BEFORE spaces::
+
+               GOOD: ["Let me think...", " The answer is 4."]
+               BAD:  ["Let me think... ", "The answer is 4."]
+
+           Tokens that straddle a boundary are conservatively masked.
+           Newlines typically merge with preceding punctuation (``":\\n"`` is
+           one token), so keep newlines with the preceding part.
+       """
+       if not isinstance(content, list):
+           return None
+
+       text_parts: list[str] = []
+       training_details: list[dict] = []
+       has_explicit_training = False
+       offset = 0
+
+       for part in content:
+           if isinstance(part, dict):
+               # Extract text (HF uses "text", also support "content"/"value")
+               text = (
+                   part.get("text") or part.get("content") or part.get("value") or ""
+               )
+               text_parts.append(text)
+
+               # Check for per-part training flags
+               part_train = part.get("train")
+               part_weight = part.get("weight")
+               if part_train is not None or part_weight is not None:
+                   has_explicit_training = True
+                   train = (
+                       part_train
+                       if part_train is not None
+                       else (part_weight not in (0, 0.0))
+                   )
+               else:
+                   train = True  # default trainable, gated by turn-level should_train
+
+               if text:
+                   training_details.append(
+                       {
+                           "begin_offset": offset,
+                           "end_offset": offset + len(text) - 1,
+                           "train": train,
+                       }
+                   )
+                   offset += len(text)
+
+       # Warn about trailing whitespace at boundaries between parts with
+       # different training flags — this almost always causes token straddling
+       if has_explicit_training and len(training_details) > 1:
+           for i in range(len(training_details) - 1):
+               cur = training_details[i]
+               nxt = training_details[i + 1]
+               if cur["train"] != nxt["train"]:
+                   boundary_text = text_parts[i]
+                   if boundary_text and boundary_text[-1] in (" ", "\t"):
+                       LOG.warning(
+                           "Content part %d ends with whitespace at a train/mask boundary. "
+                           "BPE tokenizers typically prepend spaces to word tokens, so "
+                           "the space will merge with the next part's first word and the "
+                           "resulting token will be MASKED (not trained). Move the "
+                           "whitespace to the start of the next content part instead. "
+                           "Part text: %r",
+                           i,
+                           boundary_text[-20:],
+                       )
+
+       concatenated = "".join(text_parts)
+       details = training_details if has_explicit_training else None
+       return concatenated, details
+
    def get_conversation_thread(self, prompt):
        turns = []

@@ -723,6 +880,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            if training_detail is not None:
                turn["training_detail"] = training_detail

+           # Convert list content/reasoning_content to string + auto-generated
+           # training_detail. See _convert_content_parts for whitespace guidance.
+           content_result = self._convert_content_parts(turn.get("content"))
+           if content_result is not None:
+               turn["content"] = content_result[0]
+               if content_result[1] is not None:
+                   turn["training_detail"] = content_result[1]
+
+           # Also convert reasoning_content (template_thinking_key) if it's a list
+           thinking_key = self.prompter.template_thinking_key
+           if thinking_key and thinking_key in turn:
+               reasoning_result = self._convert_content_parts(turn[thinking_key])
+               if reasoning_result is not None:
+                   turn[thinking_key] = reasoning_result[0]
+                   if reasoning_result[1] is not None:
+                       turn["reasoning_training_detail"] = reasoning_result[1]
+
            turns.append(turn)

        if self.prompter.drop_system_message and turns[0]["role"] == "system":
@@ -916,6 +916,235 @@ class TestChatTemplateConfigurations:
        LOG.debug(f"Final labels: {labels}")
        LOG.debug(f"Final input_ids: {input_ids}")

    @enable_hf_offline
    def test_content_parts_training(
        self,
        tokenizer,
        chat_template,
        chat_template_jinja,
        eos_token,
        request,
    ):
        LOG.info("Testing with content as list of parts with per-part training")

        tokenizer, chat_template_jinja = self.setup_tokenizer(
            tokenizer, chat_template, chat_template_jinja, eos_token, request
        )

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "role", "content": "content"},
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        # Dataset where assistant content is a list of parts with per-part training
        conversation = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are an AI assistant."},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is 2+2?"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "Let me think...", "train": False},
                    {"type": "text", "text": "The answer is 4.", "train": True},
                ],
            },
        ]

        dataset = Dataset.from_dict({"messages": [conversation]})
        res = strategy.tokenize_prompt(dataset[0])
        turns = strategy.get_conversation_thread(dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]

        # Find the assistant turn (last turn)
        assistant_turn_idx = len(turns) - 1
        start_idx, end_idx = strategy.find_turn(
            turns=turns, turn_idx=assistant_turn_idx
        )

        assert start_idx != -1 and end_idx != -1, (
            "Could not find assistant turn boundaries"
        )

        decoded = tokenizer.decode(input_ids[start_idx:end_idx])
        LOG.debug(f"Assistant turn decoded: {decoded}")

        # Tokenize each part separately to find their boundaries
        part1_text = "Let me think..."
        part2_text = "The answer is 4."

        # Verify the concatenated content is in the decoded output
        assert part1_text in decoded, (
            f"Part 1 '{part1_text}' not found in decoded: {decoded}"
        )
        assert part2_text in decoded, (
            f"Part 2 '{part2_text}' not found in decoded: {decoded}"
        )

        # Verify that part1 tokens (train=False) are masked
        # and part2 tokens (train=True) are labeled
        turn_labels = labels[start_idx:end_idx]

        # Find where part2 starts in the token sequence
        part1_tokens = tokenizer(part1_text, add_special_tokens=False)["input_ids"]
        part2_tokens = tokenizer(part2_text, add_special_tokens=False)["input_ids"]

        # The first part should be masked (all IGNORE_TOKEN_ID).
        # Due to token boundary alignment, check that at least the interior tokens
        # of part1 are masked
        assert any(label == IGNORE_TOKEN_ID for label in turn_labels), (
            f"Expected some masked labels for train=False part, but got {turn_labels}"
        )

        # The second part should be trained (not IGNORE_TOKEN_ID)
        assert any(label != IGNORE_TOKEN_ID for label in turn_labels), (
            f"Expected some trained labels for train=True part, but got {turn_labels}"
        )

        # More precise check: first N tokens should be masked, last M tokens should
        # be trained, where N ~ len(part1_tokens) and M ~ len(part2_tokens).
        # Allow for token boundary effects at the part boundary.
        num_masked = sum(1 for label in turn_labels if label == IGNORE_TOKEN_ID)
        num_trained = sum(1 for label in turn_labels if label != IGNORE_TOKEN_ID)

        LOG.debug(f"Turn labels: {turn_labels}")
        LOG.debug(f"Masked tokens: {num_masked}, Trained tokens: {num_trained}")
        LOG.debug(
            f"Part1 tokens: {len(part1_tokens)}, Part2 tokens: {len(part2_tokens)}"
        )

        # The number of masked tokens should be roughly the size of part1
        # and the number of trained tokens should be roughly the size of part2
        assert num_masked > 0, "Expected masked tokens for the train=False part"
        assert num_trained > 0, "Expected trained tokens for the train=True part"

    @enable_hf_offline
    def test_content_parts_with_weight(
        self,
        tokenizer,
        chat_template,
        chat_template_jinja,
        eos_token,
        request,
    ):
        LOG.info("Testing with content parts using weight field")

        tokenizer, chat_template_jinja = self.setup_tokenizer(
            tokenizer, chat_template, chat_template_jinja, eos_token, request
        )

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "role", "content": "content"},
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        # Dataset using weight instead of train
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Hello"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "Thinking step by step: ", "weight": 0},
                    {"type": "text", "text": "Hello! How can I help?", "weight": 1},
                ],
            },
        ]

        dataset = Dataset.from_dict({"messages": [conversation]})
        res = strategy.tokenize_prompt(dataset[0])
        labels = res["labels"]

        # There should be both masked and trained labels
        has_masked = any(label == IGNORE_TOKEN_ID for label in labels)
        has_trained = any(label != IGNORE_TOKEN_ID for label in labels)
        assert has_masked, "Expected masked tokens (weight=0 part + user turn)"
        assert has_trained, "Expected trained tokens (weight=1 part)"

    @enable_hf_offline
    def test_content_parts_string_passthrough(
        self,
        tokenizer,
        chat_template,
        chat_template_jinja,
        eos_token,
        request,
    ):
        LOG.info("Testing that string content still works alongside list content")

        tokenizer, chat_template_jinja = self.setup_tokenizer(
            tokenizer, chat_template, chat_template_jinja, eos_token, request
        )

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "role", "content": "content"},
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        # All list content in the conversation
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is 2+2?"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "The answer is 4.", "train": True},
                ],
            },
        ]

        dataset = Dataset.from_dict({"messages": [conversation]})
        res = strategy.tokenize_prompt(dataset[0])

        # Should tokenize without errors
        assert "input_ids" in res
        assert "labels" in res
        assert len(res["input_ids"]) > 0

    def test_get_chat_template_variables(
        self, tokenizer, chat_template, chat_template_jinja, eos_token, request
    ):
@@ -1428,3 +1657,250 @@ class TestChatTemplateToolCalling:
            assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
                f"Assistant turn {i} should be unmasked"
            )


class TestChatTemplateReasoningContent:
    """
    Test class for reasoning_content with content parts.
    """

    @enable_hf_offline
    def test_reasoning_content_with_content_parts(self, qwen3_tokenizer):
        """Test that reasoning_content as string + content as list parts works correctly.
        Content training_detail offsets should align with content-only boundaries."""
        LOG.info("Testing reasoning_content with content parts on qwen3")

        tokenizer = deepcopy(qwen3_tokenizer)

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template("qwen3"),
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                    "reasoning_content": "reasoning_content",
                },
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        # reasoning_content is a plain string, content is a list with per-part training
        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "What is 2+2?"}],
            },
            {
                "role": "assistant",
                "reasoning_content": "Step 1: 2+2=4",
                "content": [
                    {"type": "text", "text": "The answer is 4.", "train": True},
                ],
            },
        ]

        dataset = Dataset.from_dict({"messages": [conversation]})
        res = strategy.tokenize_prompt(dataset[0])
        turns = strategy.get_conversation_thread(dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]

        # Find the assistant turn
        assistant_idx = 1
        start_idx, end_idx = strategy.find_turn(
            turns=turns, turn_idx=assistant_idx, content_only=True
        )

        assert start_idx != -1 and end_idx != -1, (
            "Could not find assistant content boundaries"
        )

        # The content-only span should contain "The answer is 4." but NOT "Step 1: 2+2=4"
        decoded_span = tokenizer.decode(input_ids[start_idx:end_idx])
        assert "The answer is 4." in decoded_span, (
            f"Content not found in span: {decoded_span}"
        )
        assert "Step 1" not in decoded_span, (
            f"Reasoning should not be in content-only span: {decoded_span}"
        )

        # Verify that content tokens are trained
        content_labels = labels[start_idx:end_idx]
        assert any(label != IGNORE_TOKEN_ID for label in content_labels), (
            f"Expected trained labels in content span, got {content_labels}"
        )

    @enable_hf_offline
    def test_reasoning_content_per_part_masking(self, qwen3_tokenizer):
        """Test masking incorrect reasoning while training on self-correction.
        This is the core use case: mask out wrong thoughts, train on corrections."""
        LOG.info("Testing reasoning_content per-part masking on qwen3")

        tokenizer = deepcopy(qwen3_tokenizer)

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template("qwen3"),
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                    "reasoning_content": "reasoning_content",
                },
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        # Reasoning has wrong step (masked) then self-correction (trained)
        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "What is 2+2?"}],
            },
            {
                "role": "assistant",
                "reasoning_content": [
                    {"type": "text", "text": "Hmm maybe 2+2=5.", "train": False},
                    {"type": "text", "text": " Wait no, 2+2=4.", "train": True},
                ],
                "content": [
                    {"type": "text", "text": "The answer is 4.", "train": True},
                ],
            },
        ]

        dataset = Dataset.from_dict({"messages": [conversation]})
        res = strategy.tokenize_prompt(dataset[0])
        turns = strategy.get_conversation_thread(dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]

        # Find reasoning boundaries
        reasoning_start, reasoning_end = strategy.find_turn(
            turns=turns, turn_idx=1, reasoning_only=True
        )
        assert reasoning_start != -1 and reasoning_end != -1, (
            "Could not find reasoning boundaries"
        )

        decoded_reasoning = tokenizer.decode(input_ids[reasoning_start:reasoning_end])
        LOG.debug(f"Reasoning span: {decoded_reasoning!r}")
        assert "2+2=5" in decoded_reasoning, (
            f"Wrong step not in reasoning span: {decoded_reasoning}"
        )
        assert "2+2=4" in decoded_reasoning, (
            f"Correction not in reasoning span: {decoded_reasoning}"
        )

        # Verify reasoning labels have both masked and trained tokens
        reasoning_labels = labels[reasoning_start:reasoning_end]
        reasoning_ids = input_ids[reasoning_start:reasoning_end]

        # Decode only the trained tokens — should be exactly the self-correction
        trained_ids = [
            tid
            for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
            if lab != IGNORE_TOKEN_ID
        ]
        trained_text = tokenizer.decode(trained_ids)
        assert trained_text.strip() == "Wait no, 2+2=4.", (
            f"Expected trained reasoning to be 'Wait no, 2+2=4.', got: {trained_text!r}"
        )

        # Decode only the masked tokens — should be exactly the incorrect step
        masked_ids = [
            tid
            for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
            if lab == IGNORE_TOKEN_ID
        ]
        masked_text = tokenizer.decode(masked_ids)
        assert masked_text.strip() == "Hmm maybe 2+2=5.", (
            f"Expected masked reasoning to be 'Hmm maybe 2+2=5.', got: {masked_text!r}"
        )

        # Find content boundaries
        content_start, content_end = strategy.find_turn(
            turns=turns, turn_idx=1, content_only=True
        )
        assert content_start != -1 and content_end != -1, (
            "Could not find content boundaries"
        )

        # Content should be fully trained — decode trained tokens to verify
        content_labels = labels[content_start:content_end]
        content_ids = input_ids[content_start:content_end]
        content_trained_ids = [
            tid
            for tid, lab in zip(content_ids, content_labels, strict=True)
            if lab != IGNORE_TOKEN_ID
        ]
        content_trained_text = tokenizer.decode(content_trained_ids)
        assert "The answer is 4." in content_trained_text, (
            f"Expected 'The answer is 4.' in trained content tokens, "
            f"got: {content_trained_text!r}"
        )
        assert all(label != IGNORE_TOKEN_ID for label in content_labels), (
            f"Expected all content labels trained, got {content_labels}"
        )

    @enable_hf_offline
    def test_reasoning_content_as_list_no_training_flags(self, qwen3_tokenizer):
        """Test that reasoning_content as list without training flags still works."""
        LOG.info("Testing reasoning_content as list without training flags on qwen3")

        tokenizer = deepcopy(qwen3_tokenizer)

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template("qwen3"),
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                    "reasoning_content": "reasoning_content",
                },
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        # Both as lists, no per-part training flags
        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "What is 2+2?"}],
            },
            {
                "role": "assistant",
                "reasoning_content": [
                    {"type": "text", "text": "Step 1: addition."},
                    {"type": "text", "text": " Step 2: 2+2=4."},
                ],
                "content": [
                    {"type": "text", "text": "The answer is 4."},
                ],
            },
        ]

        dataset = Dataset.from_dict({"messages": [conversation]})
        res = strategy.tokenize_prompt(dataset[0])

        # Should tokenize without errors
        assert "input_ids" in res
        assert "labels" in res
        assert len(res["input_ids"]) > 0

        # Verify the full output contains both reasoning and content
        full_text = tokenizer.decode(res["input_ids"])
        assert "Step 1: addition." in full_text
        assert "Step 2: 2+2=4." in full_text
        assert "The answer is 4." in full_text