handle trainable/masked spans in content and reasoning content (#3592)

This commit is contained in:
Wing Lian
2026-04-10 14:11:10 -04:00
committed by GitHub
parent e7a6a5b529
commit 315cdeede9
3 changed files with 767 additions and 10 deletions

View File

@@ -916,6 +916,235 @@ class TestChatTemplateConfigurations:
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
@enable_hf_offline
def test_content_parts_training(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing with content as list of parts with per-part training")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Dataset where assistant content is a list of parts with per-part training
conversation = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are an AI assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "What is 2+2?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think...", "train": False},
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find the assistant turn (last turn)
assistant_turn_idx = len(turns) - 1
start_idx, end_idx = strategy.find_turn(
turns=turns, turn_idx=assistant_turn_idx
)
assert start_idx != -1 and end_idx != -1, (
"Could not find assistant turn boundaries"
)
decoded = tokenizer.decode(input_ids[start_idx:end_idx])
LOG.debug(f"Assistant turn decoded: {decoded}")
# Tokenize each part separately to find their boundaries
part1_text = "Let me think..."
part2_text = "The answer is 4."
# Verify the concatenated content is in the decoded output
assert part1_text in decoded, (
f"Part 1 '{part1_text}' not found in decoded: {decoded}"
)
assert part2_text in decoded, (
f"Part 2 '{part2_text}' not found in decoded: {decoded}"
)
# Verify that part1 tokens (train=False) are masked
# and part2 tokens (train=True) are labeled
turn_labels = labels[start_idx:end_idx]
# Find where part2 starts in the token sequence
part1_tokens = tokenizer(part1_text, add_special_tokens=False)["input_ids"]
part2_tokens = tokenizer(part2_text, add_special_tokens=False)["input_ids"]
# The first part should be masked (all IGNORE_TOKEN_ID)
# Due to token boundary alignment, check that at least the interior tokens
# of part1 are masked
assert any(label == IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected some masked labels for train=False part, but got {turn_labels}"
)
# The second part should be trained (not IGNORE_TOKEN_ID)
assert any(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected some trained labels for train=True part, but got {turn_labels}"
)
# More precise check: first N tokens should be masked, last M tokens should be trained
# where N ~ len(part1_tokens) and M ~ len(part2_tokens)
# Allow for token boundary effects at the boundary
num_masked = sum(1 for label in turn_labels if label == IGNORE_TOKEN_ID)
num_trained = sum(1 for label in turn_labels if label != IGNORE_TOKEN_ID)
LOG.debug(f"Turn labels: {turn_labels}")
LOG.debug(f"Masked tokens: {num_masked}, Trained tokens: {num_trained}")
LOG.debug(
f"Part1 tokens: {len(part1_tokens)}, Part2 tokens: {len(part2_tokens)}"
)
# The number of masked tokens should be roughly the size of part1
# and the number of trained tokens should be roughly the size of part2
assert num_masked > 0, "Expected masked tokens for the train=False part"
assert num_trained > 0, "Expected trained tokens for the train=True part"
@enable_hf_offline
def test_content_parts_with_weight(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing with content parts using weight field")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Dataset using weight instead of train
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Thinking step by step: ", "weight": 0},
{"type": "text", "text": "Hello! How can I help?", "weight": 1},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
labels = res["labels"]
# There should be both masked and trained labels
has_masked = any(label == IGNORE_TOKEN_ID for label in labels)
has_trained = any(label != IGNORE_TOKEN_ID for label in labels)
assert has_masked, "Expected masked tokens (weight=0 part + user turn)"
assert has_trained, "Expected trained tokens (weight=1 part)"
@enable_hf_offline
def test_content_parts_string_passthrough(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing that string content still works alongside list content")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# All list content in the conversation
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is 2+2?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
# Should tokenize without errors
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
def test_get_chat_template_variables(
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
):
@@ -1428,3 +1657,250 @@ class TestChatTemplateToolCalling:
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Assistant turn {i} should be unmasked"
)
class TestChatTemplateReasoningContent:
"""
Test class for reasoning_content with content parts.
"""
@enable_hf_offline
def test_reasoning_content_with_content_parts(self, qwen3_tokenizer):
"""Test that reasoning_content as string + content as list parts works correctly.
Content training_detail offsets should align with content-only boundaries."""
LOG.info("Testing reasoning_content with content parts on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# reasoning_content is a plain string, content is list with per-part training
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": "Step 1: 2+2=4",
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find the assistant turn
assistant_idx = 1
start_idx, end_idx = strategy.find_turn(
turns=turns, turn_idx=assistant_idx, content_only=True
)
assert start_idx != -1 and end_idx != -1, (
"Could not find assistant content boundaries"
)
# The content-only span should contain "The answer is 4." but NOT "Step 1: 2+2=4"
decoded_span = tokenizer.decode(input_ids[start_idx:end_idx])
assert "The answer is 4." in decoded_span, (
f"Content not found in span: {decoded_span}"
)
assert "Step 1" not in decoded_span, (
f"Reasoning should not be in content-only span: {decoded_span}"
)
# Verify that content tokens are trained
content_labels = labels[start_idx:end_idx]
assert any(label != IGNORE_TOKEN_ID for label in content_labels), (
f"Expected trained labels in content span, got {content_labels}"
)
@enable_hf_offline
def test_reasoning_content_per_part_masking(self, qwen3_tokenizer):
"""Test masking incorrect reasoning while training on self-correction.
This is the core use case: mask out wrong thoughts, train on corrections."""
LOG.info("Testing reasoning_content per-part masking on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Reasoning has wrong step (masked) then self-correction (trained)
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": False},
{"type": "text", "text": " Wait no, 2+2=4.", "train": True},
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find reasoning boundaries
reasoning_start, reasoning_end = strategy.find_turn(
turns=turns, turn_idx=1, reasoning_only=True
)
assert reasoning_start != -1 and reasoning_end != -1, (
"Could not find reasoning boundaries"
)
decoded_reasoning = tokenizer.decode(input_ids[reasoning_start:reasoning_end])
LOG.debug(f"Reasoning span: {decoded_reasoning!r}")
assert "2+2=5" in decoded_reasoning, (
f"Wrong step not in reasoning span: {decoded_reasoning}"
)
assert "2+2=4" in decoded_reasoning, (
f"Correction not in reasoning span: {decoded_reasoning}"
)
# Verify reasoning labels have both masked and trained tokens
reasoning_labels = labels[reasoning_start:reasoning_end]
reasoning_ids = input_ids[reasoning_start:reasoning_end]
# Decode only the trained tokens — should be exactly the self-correction
trained_ids = [
tid
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
if lab != IGNORE_TOKEN_ID
]
trained_text = tokenizer.decode(trained_ids)
assert trained_text.strip() == "Wait no, 2+2=4.", (
f"Expected trained reasoning to be 'Wait no, 2+2=4.', got: {trained_text!r}"
)
# Decode only the masked tokens — should be exactly the incorrect step
masked_ids = [
tid
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
if lab == IGNORE_TOKEN_ID
]
masked_text = tokenizer.decode(masked_ids)
assert masked_text.strip() == "Hmm maybe 2+2=5.", (
f"Expected masked reasoning to be 'Hmm maybe 2+2=5.', got: {masked_text!r}"
)
# Find content boundaries
content_start, content_end = strategy.find_turn(
turns=turns, turn_idx=1, content_only=True
)
assert content_start != -1 and content_end != -1, (
"Could not find content boundaries"
)
# Content should be fully trained — decode trained tokens to verify
content_labels = labels[content_start:content_end]
content_ids = input_ids[content_start:content_end]
content_trained_ids = [
tid
for tid, lab in zip(content_ids, content_labels, strict=True)
if lab != IGNORE_TOKEN_ID
]
content_trained_text = tokenizer.decode(content_trained_ids)
assert "The answer is 4." in content_trained_text, (
f"Expected 'The answer is 4.' in trained content tokens, "
f"got: {content_trained_text!r}"
)
assert all(label != IGNORE_TOKEN_ID for label in content_labels), (
f"Expected all content labels trained, got {content_labels}"
)
@enable_hf_offline
def test_reasoning_content_as_list_no_training_flags(self, qwen3_tokenizer):
"""Test that reasoning_content as list without training flags still works."""
LOG.info("Testing reasoning_content as list without training flags on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Both as lists, no per-part training flags
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Step 1: addition."},
{"type": "text", "text": " Step 2: 2+2=4."},
],
"content": [
{"type": "text", "text": "The answer is 4."},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
# Should tokenize without errors
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
# Verify the full output contains both reasoning and content
full_text = tokenizer.decode(res["input_ids"])
assert "Step 1: addition." in full_text
assert "Step 2: 2+2=4." in full_text
assert "The answer is 4." in full_text