handle trainable/masked spans in content and reasoning content (#3592)

This commit is contained in:
Wing Lian
2026-04-10 14:11:10 -04:00
committed by GitHub
parent e7a6a5b529
commit 315cdeede9
3 changed files with 767 additions and 10 deletions

View File

@@ -471,6 +471,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")
reasoning_train_detail = turn.get("reasoning_training_detail")
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
@@ -479,8 +480,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
elif train_detail is not None or reasoning_train_detail is not None:
should_train = bool(train_detail) or bool(reasoning_train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train
@@ -500,15 +501,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
continue
thinking_key = self.prompter.template_thinking_key
has_reasoning = thinking_key and turn.get(thinking_key) is not None
has_any_detail = train_detail or reasoning_train_detail
# When train_detail is present and the turn has reasoning_content,
# use content_only=True so find_turn returns content-only boundaries
# (excluding reasoning_content + template separator tokens).
use_content_only = bool(has_any_detail and has_reasoning)
turn_start_idx, turn_end_idx = self.find_turn(
turns=turns, turn_idx=index, tools=tools
turns=turns,
turn_idx=index,
tools=tools,
content_only=use_content_only,
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
# Block multi-content for now
if not isinstance(content, str):
raise ValueError(
"`train_detail` is not supported when `content` is not a string."
@@ -526,7 +538,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
)
else:
elif not reasoning_train_detail:
# No per-part detail on either field — train the whole span
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
@@ -534,6 +547,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
# Handle reasoning_content training_detail separately
if should_train and reasoning_train_detail and has_reasoning:
reasoning_text = turn[thinking_key]
if not isinstance(reasoning_text, str):
raise ValueError(
"`reasoning_training_detail` is not supported when reasoning_content is not a string."
)
reasoning_start, reasoning_end = self.find_turn(
turns=turns,
turn_idx=index,
tools=tools,
reasoning_only=True,
)
if reasoning_start != -1 and reasoning_end != -1:
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
reasoning_text, reasoning_train_detail
)
LOG.debug(f"Reasoning token offsets: {token_offsets}")
for i, offset in enumerate(token_offsets):
if offset != IGNORE_TOKEN_ID and reasoning_start + i < len(
input_ids
):
labels[reasoning_start + i] = input_ids[reasoning_start + i]
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle special tokens (EOT and EOS)
@@ -611,10 +650,24 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return -1
def find_turn(
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
self,
turns: list[dict],
turn_idx: int,
tools: list[dict] | None = None,
content_only: bool = False,
reasoning_only: bool = False,
):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
content_only: If True and the turn has reasoning_content (template_thinking_key),
preserve reasoning_content in the dummy turn so the diff only captures the
content field boundaries. This is needed for correct training_detail alignment
when reasoning_content is present.
reasoning_only: If True, preserve content in the dummy turn and replace
reasoning_content with a dummy, so the diff only captures the
reasoning_content field boundaries.
"""
if turn_idx >= len(turns):
@@ -628,10 +681,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
):
return -1, -1
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
thinking_key = self.prompter.template_thinking_key
if reasoning_only:
# Keep content as-is, replace reasoning with dummy
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": turns[turn_idx].get("content", ""),
}
if thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = "[[dummy_reasoning]]"
else:
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
# When content_only is True, copy reasoning_content to the dummy turn so
# the diff only captures the content field (not reasoning + separator).
if content_only and thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = turns[turn_idx][thinking_key]
# Create conversation versions
turns_with_empty = turns[:turn_idx] + [empty_turn]
@@ -697,6 +766,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return start_idx, end_idx
@staticmethod
def _convert_content_parts(
content,
) -> tuple[str, list[dict] | None] | None:
"""Convert list content to concatenated string + optional training_detail.
When content is a list of dicts (content parts), each part can specify:
- ``text``, ``content``, or ``value``: the text string
- ``train`` (bool) or ``weight`` (0/1): per-part training flag
Returns ``(concatenated_text, training_details_or_None)`` if content was
a list, or ``None`` if content was not a list (no conversion needed).
.. note::
**Whitespace at part boundaries matters.** BPE tokenizers prepend
spaces to word tokens (e.g. ``" answer"`` is one token). Always
split BEFORE spaces::
GOOD: ["Let me think...", " The answer is 4."]
BAD: ["Let me think... ", "The answer is 4."]
Tokens that straddle a boundary are conservatively masked.
Newlines typically merge with preceding punctuation (``":\\n"`` is
one token), so keep newlines with the preceding part.
"""
if not isinstance(content, list):
return None
text_parts: list[str] = []
training_details: list[dict] = []
has_explicit_training = False
offset = 0
for part in content:
if isinstance(part, dict):
# Extract text (HF uses "text", also support "content"/"value")
text = (
part.get("text") or part.get("content") or part.get("value") or ""
)
text_parts.append(text)
# Check for per-part training flags
part_train = part.get("train")
part_weight = part.get("weight")
if part_train is not None or part_weight is not None:
has_explicit_training = True
train = (
part_train
if part_train is not None
else (part_weight not in (0, 0.0))
)
else:
train = True # default trainable, gated by turn-level should_train
if text:
training_details.append(
{
"begin_offset": offset,
"end_offset": offset + len(text) - 1,
"train": train,
}
)
offset += len(text)
# Warn about trailing whitespace at boundaries between parts with
# different training flags — this almost always causes token straddling
if has_explicit_training and len(training_details) > 1:
for i in range(len(training_details) - 1):
cur = training_details[i]
nxt = training_details[i + 1]
if cur["train"] != nxt["train"]:
boundary_text = text_parts[i]
if boundary_text and boundary_text[-1] in (" ", "\t"):
LOG.warning(
"Content part %d ends with whitespace at a train/mask boundary. "
"BPE tokenizers typically prepend spaces to word tokens, so "
"the space will merge with the next part's first word and the "
"resulting token will be MASKED (not trained). Move the "
"whitespace to the start of the next content part instead. "
"Part text: %r",
i,
boundary_text[-20:],
)
concatenated = "".join(text_parts)
details = training_details if has_explicit_training else None
return concatenated, details
def get_conversation_thread(self, prompt):
turns = []
@@ -723,6 +880,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if training_detail is not None:
turn["training_detail"] = training_detail
# Convert list content/reasoning_content to string + auto-generated
# training_detail. See _convert_content_parts for whitespace guidance.
content_result = self._convert_content_parts(turn.get("content"))
if content_result is not None:
turn["content"] = content_result[0]
if content_result[1] is not None:
turn["training_detail"] = content_result[1]
# Also convert reasoning_content (template_thinking_key) if it's a list
thinking_key = self.prompter.template_thinking_key
if thinking_key and thinking_key in turn:
reasoning_result = self._convert_content_parts(turn[thinking_key])
if reasoning_result is not None:
turn[thinking_key] = reasoning_result[0]
if reasoning_result[1] is not None:
turn["reasoning_training_detail"] = reasoning_result[1]
turns.append(turn)
if self.prompter.drop_system_message and turns[0]["role"] == "system":