Add tests

This commit is contained in:
Chirag Jain
2024-07-12 09:04:59 +05:30
parent 5edaad5b8b
commit 4e38cea6b8
2 changed files with 48 additions and 1 deletions

View File

@@ -59,7 +59,7 @@ def chat_templates(user_choice: str, tokenizer=None):
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
]
LOG.warning(
f"No chat template found on tokenizer, falling back to {user_choice}"
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
)
if user_choice in templates:

View File

@@ -83,6 +83,53 @@ def fixture_llama3_tokenizer():
return tokenizer
class TestChatTemplates:
    """
    Tests the chat_templates function.

    Covers three kinds of template selector: a plain template name, the
    "tokenizer_default" selector, and the "tokenizer_default_fallback_<name>"
    selector, each with and without a tokenizer / tokenizer chat template.
    """

    def test_invalid_chat_template(self):
        """An unknown template name raises ValueError with the expected message."""
        with pytest.raises(ValueError) as exc:
            chat_templates("invalid_template")
        # The message check has to run after the `with` block: statements placed
        # after the raising call inside the block never execute. It also has to
        # inspect `exc.value`, because `str(exc)` stringifies pytest's
        # ExceptionInfo wrapper rather than the raised exception.
        # NOTE(review): expected text is taken from the original assertion;
        # confirm it matches the message raised by chat_templates.
        assert str(exc.value) == "Template 'invalid_template' not found."

    def test_tokenizer_default_no_tokenizer(self):
        """"tokenizer_default" without a tokenizer raises ValueError."""
        with pytest.raises(ValueError):
            chat_templates("tokenizer_default", tokenizer=None)

    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
        """"tokenizer_default" raises ValueError for the unmodified fixture tokenizer.

        Presumably the llama3_tokenizer fixture ships without a chat template;
        verify against the fixture definition.
        """
        with pytest.raises(ValueError):
            chat_templates("tokenizer_default", tokenizer=llama3_tokenizer)

    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
        """"tokenizer_default" returns the template set on the tokenizer."""
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = chat_templates(
            "tokenizer_default", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"

    def test_tokenizer_default_fallback_no_tokenizer(self):
        """A "tokenizer_default_fallback_*" selector without a tokenizer raises ValueError."""
        with pytest.raises(ValueError):
            chat_templates("tokenizer_default_fallback_test", tokenizer=None)

    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        """The fallback selector yields the named template ("chatml") for the unmodified fixture tokenizer."""
        chat_template_str = chat_templates(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == chat_templates("chatml")

    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        """The fallback selector prefers the tokenizer's own template when one is set."""
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = chat_templates(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"
class TestAssistantChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.