Files
axolotl/tests/prompt_strategies/test_dpo_chat_templates.py
Wing Lian 05f03b541a hf offline decorator for tests to workaround rate limits (#2452) [skip ci]
* hf offline decorator for tests to workaround rate limits

* fail quicker so we can see logs

* try new cache name

* limit files downloaded

* phi mini predownload

* offline decorator for phi tokenizer

* handle meta llama 8b offline too

* make sure to return fixtures if they are wrapped too

* more fixes

* more things offline

* more offline things

* fix the env var

* fix the model name

* handle gemma also

* force reload of modules to recheck offline status

* prefetch mistral too

* use reset_sessions so hub picks up offline mode

* more fixes

* rename so it doesn't seem like a context manager

* fix backoff

* switch out tinyshakespeare dataset since it runs a py script to fetch data and doesn't work offline

* include additional dataset

* more fixes

* more fixes

* replace tiny shakespeare dataset

* skip some tests for now

* use more robust check using snapshot download to determine if a dataset name is on the hub

* typo for skip reason

* use local_files_only

* more fixtures

* remove local only

* use tiny shakespeare as pretrain dataset and streaming can't be offline even if precached

* make sure fixtures aren't offline

improve the offline reset
try bumping version of datasets
reorder reloading and setting
prime a new cache
run the tests now with fresh cache
try with a static cache

* now run all the ci again with hopefully a correct cache

* skip wonky tests for now

* skip wonky tests for now

* handle offline mode for model card creation
2025-03-28 19:20:46 -04:00

226 lines
7.2 KiB
Python

"""
tests for chat_template prompt strategy
"""
import unittest
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
    """Single-row DPO dataset using the default field and role names.

    The prompt lives under ``messages`` as three ``role``/``content`` turns
    (user "hello", assistant "hello", user "goodbye"). The competing
    assistant replies are stored under ``chosen`` ("goodbye") and
    ``rejected`` ("party on").
    """
    # pylint: disable=duplicate-code

    def turn(role, content):
        return {"role": role, "content": content}

    row = {
        "messages": [
            turn("user", "hello"),
            turn("assistant", "hello"),
            turn("user", "goodbye"),
        ],
        "chosen": turn("assistant", "goodbye"),
        "rejected": turn("assistant", "party on"),
    }
    return Dataset.from_list([row])
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
    """Single-row DPO dataset using non-default field and role names.

    The prompt lives under ``conversation`` as ``speaker``/``text`` turns with
    ``human``/``agent`` speakers. The competing replies are stored under
    ``better`` ("goodbye") and ``worse`` ("party on"). Tests that consume
    this fixture must therefore supply a field and role mapping in the
    dataset config.
    """
    # pylint: disable=duplicate-code

    def turn(speaker, text):
        return {"speaker": speaker, "text": text}

    row = {
        "conversation": [
            turn("human", "hello"),
            turn("agent", "hello"),
            turn("human", "goodbye"),
        ],
        "better": turn("agent", "goodbye"),
        "worse": turn("agent", "party on"),
    }
    return Dataset.from_list([row])
@pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer():
    """Tokenizer for microsoft/Phi-3-medium-128k-instruct.

    The fixture function is wrapped with ``enable_hf_offline``.
    """
    return AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
@pytest.fixture(name="gemma_tokenizer")
@enable_hf_offline
def fixture_gemma_tokenizer():
    """Tokenizer for unsloth/gemma-2b-it, pinned to revision ``703fb4a``.

    The fixture function is wrapped with ``enable_hf_offline``.
    """
    repo_id = "unsloth/gemma-2b-it"
    return AutoTokenizer.from_pretrained(repo_id, revision="703fb4a")
class TestAssistantDPOChatTemplateLlama3:
    """
    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.

    Both tests render the same three-turn conversation, so they share one
    expected prompt. The chosen/rejected completions are each expected to
    end with ``<|eot_id|>``.
    """

    # llama-3 rendering of the whole prompt: BOS, the three turns, then an
    # open assistant header that the chosen/rejected completions follow.
    expected_prompt = (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
        """Rows using the default field names need no mapping in the config."""
        # pylint: disable=duplicate-code
        dataset_cfg = {"type": "chat_template"}
        cfg = DictDefault({"chat_template": "llama3", "datasets": [dataset_cfg]})
        transform_fn = default(cfg)

        result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)

        assert result["prompt"] == self.expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
        """Custom field and role names are remapped through the dataset config."""
        # pylint: disable=duplicate-code
        role_map = {
            "user": ["human"],
            "assistant": ["agent"],
            "system": ["sys"],
        }
        dataset_cfg = {
            "type": "chat_template",
            "field_messages": "conversation",
            "field_chosen": "better",
            "field_rejected": "worse",
            "message_field_role": "speaker",
            "message_field_content": "text",
            "roles": role_map,
        }
        cfg = DictDefault({"chat_template": "llama3", "datasets": [dataset_cfg]})
        transform_fn = default(cfg)

        result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)

        assert result["prompt"] == self.expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"
class TestAssistantDPOChatTemplatePhi3:
    """
    Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
    """

    def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
        """``tokenizer_default`` renders with the Phi-3 tokenizer's own template."""
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "chat_template": "tokenizer_default",
                "datasets": [{"type": "chat_template"}],
            }
        )
        transform_fn = default(cfg)

        result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)

        # Each turn is "<|role|>\n<content><|end|>\n"; the prompt ends on an
        # open assistant turn for the completions to continue.
        expected_prompt = "".join(
            [
                "<|user|>\nhello<|end|>\n",
                "<|assistant|>\nhello<|end|>\n",
                "<|user|>\ngoodbye<|end|>\n",
                "<|assistant|>\n",
            ]
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<|end|>"
        assert result["rejected"] == "party on<|end|>"
class TestAssistantDPOChatTemplateGemma:
    """
    Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
    """

    def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
        """``tokenizer_default`` renders with the Gemma tokenizer's own template."""
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "chat_template": "tokenizer_default",
                "datasets": [{"type": "chat_template"}],
            }
        )
        transform_fn = default(cfg)

        result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)

        # Gemma labels the assistant role "model" in the rendered prompt, and
        # the prompt ends on an open model turn.
        expected_prompt = "".join(
            [
                "<bos><start_of_turn>user\nhello<end_of_turn>\n",
                "<start_of_turn>model\nhello<end_of_turn>\n",
                "<start_of_turn>user\ngoodbye<end_of_turn>\n",
                "<start_of_turn>model\n",
            ]
        )
        assert result["prompt"] == expected_prompt
        assert result["chosen"] == "goodbye<end_of_turn>"
        assert result["rejected"] == "party on<end_of_turn>"
if __name__ == "__main__":
    # The test classes in this module are plain pytest classes (not
    # unittest.TestCase subclasses) and depend on pytest fixtures, so
    # unittest.main() would collect nothing. Run this file through pytest
    # instead and propagate its exit status.
    raise SystemExit(pytest.main([__file__]))