Add tests
This commit is contained in:
@@ -59,7 +59,7 @@ def chat_templates(user_choice: str, tokenizer=None):
|
||||
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
||||
]
|
||||
LOG.warning(
|
||||
f"No chat template found on tokenizer, falling back to {user_choice}"
|
||||
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
|
||||
)
|
||||
|
||||
if user_choice in templates:
|
||||
|
||||
@@ -83,6 +83,53 @@ def fixture_llama3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestChatTemplates:
    """
    Tests the chat_templates function.
    """

    def test_invalid_chat_template(self):
        """An unknown template name must raise ValueError with a clear message."""
        with pytest.raises(ValueError) as exc:
            chat_templates("invalid_template")
        # Compare the raised exception's message, not the ExceptionInfo wrapper:
        # str(exc) on a pytest.ExceptionInfo is not the plain message and would
        # never equal it, so the original assertion could not pass.
        assert str(exc.value) == "Template 'invalid_template' not found."

    def test_tokenizer_default_no_tokenizer(self):
        """'tokenizer_default' without a tokenizer cannot resolve a template."""
        with pytest.raises(ValueError):
            chat_templates("tokenizer_default", tokenizer=None)

    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
        """'tokenizer_default' with a tokenizer lacking chat_template must fail."""
        with pytest.raises(ValueError):
            chat_templates("tokenizer_default", tokenizer=llama3_tokenizer)

    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
        """'tokenizer_default' returns the tokenizer's own chat_template verbatim."""
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = chat_templates(
            "tokenizer_default", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"

    def test_tokenizer_default_fallback_no_tokenizer(self):
        """Fallback variants still require a tokenizer to inspect."""
        with pytest.raises(ValueError):
            chat_templates("tokenizer_default_fallback_test", tokenizer=None)

    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        """With no chat_template on the tokenizer, fall back to the named template."""
        chat_template_str = chat_templates(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == chat_templates("chatml")

    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        """With a chat_template present, the tokenizer's template wins over fallback."""
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = chat_templates(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"
||||
|
||||
class TestAssistantChatTemplateLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||
|
||||
Reference in New Issue
Block a user