* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e]

  - Add 'tool' to default role_map_inv in dpo/chat_template.py default() and argilla_chat() so datasets with tool-call messages no longer raise KeyError: 'tool' (closes #3217)
  - Fix generate_dataset_hash_from_config to use canonical tokenizer config + overrides content instead of tokenizer.name_or_path when added_tokens_overrides is set, preventing cache busting when only output_dir changes (closes #3303)
  - Add three Pydantic config validators to AxolotlConfigWCapabilities:
    * save_strategy: 'best' requires metric_for_best_model
    * streaming=True is incompatible with val_set_size > 0
    * lora_target_modules list entries must be valid Python regex patterns
  - Tests for all three changes

* review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash

* chore: lint

* move the validators out of the w/ capabilities schema

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
382 lines
12 KiB
Python
382 lines
12 KiB
Python
"""
|
|
Tests for the DPO chat_template prompt strategy.
|
|
"""
|
|
|
|
import unittest
|
|
|
|
import pytest
|
|
from datasets import Dataset
|
|
from transformers import AutoTokenizer
|
|
|
|
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
from tests.hf_offline_utils import enable_hf_offline
|
|
|
|
|
|
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
    """Build a one-row dataset using the ``messages``/``chosen``/``rejected`` layout.

    The row holds a three-turn ``role``/``content`` history ending on a user
    turn, plus one preferred and one dispreferred assistant reply.
    """
    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hello"},
        {"role": "user", "content": "goodbye"},
    ]
    row = {
        "messages": history,
        "chosen": {"role": "assistant", "content": "goodbye"},
        "rejected": {"role": "assistant", "content": "party on"},
    }
    return Dataset.from_list([row])
|
|
|
|
|
|
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
    """Build a one-row dataset that uses non-default field and role names.

    The history lives under ``conversation`` with ``speaker``/``text`` keys and
    ``human``/``agent`` speakers; the two candidate replies are stored under
    ``better`` and ``worse``.
    """
    turns = [
        {"speaker": "human", "text": "hello"},
        {"speaker": "agent", "text": "hello"},
        {"speaker": "human", "text": "goodbye"},
    ]
    row = {
        "conversation": turns,
        "better": {"speaker": "agent", "text": "goodbye"},
        "worse": {"speaker": "agent", "text": "party on"},
    }
    return Dataset.from_list([row])
|
|
|
|
|
|
@pytest.fixture(name="argilla_chat_dataset")
def fixture_argilla_chat_dataset():
    """Build a one-row dataset whose ``chosen``/``rejected`` are full conversations.

    Both conversations open with the same user turn and differ only in the
    final assistant reply.
    """
    opening_turn = {"role": "user", "content": "hello"}
    row = {
        "chosen": [
            dict(opening_turn),
            {"role": "assistant", "content": "goodbye"},
        ],
        "rejected": [
            dict(opening_turn),
            {"role": "assistant", "content": "party on"},
        ],
    }
    return Dataset.from_list([row])
|
|
|
|
|
|
@pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer():
    """Load the ``microsoft/Phi-3-medium-128k-instruct`` tokenizer for the Phi-3 tests."""
    return AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
|
|
|
|
|
|
@pytest.fixture(name="gemma_tokenizer")
@enable_hf_offline
def fixture_gemma_tokenizer():
    """Load the ``unsloth/gemma-2b-it`` tokenizer, pinned to revision ``703fb4a``."""
    return AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
|
|
|
|
|
|
class TestAssistantDPOChatTemplateLlama3:
    """
    Tests for assistant-style DPO datasets rendered with the llama-3 chat template.
    """

    # Both tests feed the same three-turn conversation (under different field
    # names), so they expect the same rendered prompt ending in an open
    # assistant header.
    _EXPECTED_PROMPT = (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
        """Default field names and roles produce the expected prompt and completions."""
        cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [{"type": "chat_template"}],
            }
        )
        transform_fn, _ = default(cfg)

        result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)

        assert result["prompt"] == self._EXPECTED_PROMPT
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
        """Custom field names and role aliases yield the same output as the defaults."""
        dataset_cfg = {
            "type": "chat_template",
            "field_messages": "conversation",
            "field_chosen": "better",
            "field_rejected": "worse",
            "message_field_role": "speaker",
            "message_field_content": "text",
            "roles": {
                "user": ["human"],
                "assistant": ["agent"],
                "system": ["sys"],
            },
        }
        cfg = DictDefault({"chat_template": "llama3", "datasets": [dataset_cfg]})
        transform_fn, _ = default(cfg)

        result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)

        assert result["prompt"] == self._EXPECTED_PROMPT
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"
|
|
|
|
|
|
class TestAssistantDPOChatTemplatePhi3:
    """
    Tests for assistant-style DPO datasets rendered with the phi-3 tokenizer's own chat template.
    """

    @pytest.mark.xfail(reason="likely upstream issue from v5.4.0")
    def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
        """``tokenizer_default`` renders the phi-3 prompt and ``<|end|>``-terminated replies."""
        cfg = DictDefault(
            {
                "chat_template": "tokenizer_default",
                "datasets": [{"type": "chat_template"}],
            }
        )
        transform_fn, _ = default(cfg)

        result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)

        expected_prompt = (
            "<|user|>\nhello<|end|>\n"
            "<|assistant|>\nhello<|end|>\n"
            "<|user|>\ngoodbye<|end|>\n"
            "<|assistant|>\n"
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|end|>"
        assert result["rejected"] == "party on<|end|>"
|
|
|
|
|
|
class TestAssistantDPOChatTemplateGemma:
    """
    Tests for assistant-style DPO datasets rendered with the gemma tokenizer's own chat template.
    """

    def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
        """``tokenizer_default`` renders gemma turns, with assistant turns labelled ``model``."""
        cfg = DictDefault(
            {
                "chat_template": "tokenizer_default",
                "datasets": [{"type": "chat_template"}],
            }
        )
        transform_fn, _ = default(cfg)

        result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)

        expected_prompt = (
            "<bos><start_of_turn>user\nhello<end_of_turn>\n"
            "<start_of_turn>model\nhello<end_of_turn>\n"
            "<start_of_turn>user\ngoodbye<end_of_turn>\n"
            "<start_of_turn>model\n"
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<end_of_turn>"
        assert result["rejected"] == "party on<end_of_turn>"
|
|
|
|
|
|
class TestArgillaChatDPOChatTemplate:
    """
    Tests for argilla_chat-style datasets, where chosen/rejected each hold a full conversation.
    """

    def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
        """The llama-3 template yields the shared user turn as prompt and the final replies as completions."""
        cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [{"type": "chat_template.argilla_chat"}],
            }
        )
        transform_fn, _ = argilla_chat(cfg)

        result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)

        expected_prompt = (
            "<|begin_of_text|>"
            "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    @pytest.mark.xfail(reason="likely upstream issue from v5.4.0")
    def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
        """The phi-3 tokenizer template yields the equivalent prompt and ``<|end|>``-terminated replies."""
        cfg = DictDefault(
            {
                "chat_template": "tokenizer_default",
                "datasets": [{"type": "chat_template.argilla_chat"}],
            }
        )
        transform_fn, _ = argilla_chat(cfg)

        result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)

        expected_prompt = "<|user|>\nhello<|end|>\n" "<|assistant|>\n"
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|end|>"
        assert result["rejected"] == "party on<|end|>"
|
|
|
|
|
|
class TestDPOChatTemplateToolRole:
    """
    Check that the DPO chat template strategy accepts tool-role messages without a KeyError.

    Regression test for https://github.com/axolotl-ai-cloud/axolotl/issues/3217
    """

    def test_tool_role_default_no_key_error(self, llama3_tokenizer):
        """Messages list with a 'tool' role should not raise KeyError."""
        row = {
            "messages": [
                {"role": "user", "content": "What is the weather?"},
                {"role": "assistant", "content": "Let me check."},
                {"role": "tool", "content": "22°C, sunny."},
            ],
            "chosen": {"role": "assistant", "content": "It is 22°C and sunny."},
            "rejected": {"role": "assistant", "content": "I don't know."},
        }
        dataset = Dataset.from_list([row])
        cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [{"type": "chat_template"}],
            }
        )
        transform_fn, _ = default(cfg)

        # Should not raise KeyError: 'tool'
        result = transform_fn(dataset[0], tokenizer=llama3_tokenizer)

        # Only the presence of the three output keys is checked here.
        for key in ("prompt", "chosen", "rejected"):
            assert key in result

    def test_tool_role_custom_mapping_preserved(self, llama3_tokenizer):
        """A user-supplied roles mapping that overrides 'tool' is still respected."""
        row = {
            "messages": [
                {"role": "user", "content": "hello"},
                {"role": "tool_result", "content": "42"},
            ],
            "chosen": {"role": "assistant", "content": "The answer is 42."},
            "rejected": {"role": "assistant", "content": "Unknown."},
        }
        dataset = Dataset.from_list([row])
        # The dataset labels its tool turn "tool_result"; the explicit mapping
        # below lists that label under the "tool" role.
        dataset_cfg = {
            "type": "chat_template",
            "roles": {
                "user": ["user"],
                "assistant": ["assistant"],
                "system": ["system"],
                "tool": ["tool_result"],
            },
        }
        cfg = DictDefault({"chat_template": "llama3", "datasets": [dataset_cfg]})
        transform_fn, _ = default(cfg)

        result = transform_fn(dataset[0], tokenizer=llama3_tokenizer)

        assert "prompt" in result
|
|
|
|
|
|
if __name__ == "__main__":
    # The test classes in this module are plain pytest-style classes driven by
    # fixtures, not unittest.TestCase subclasses, so unittest.main() would
    # collect none of them. Run this file through pytest instead and propagate
    # its exit status to the caller.
    raise SystemExit(pytest.main([__file__]))
|