Add tests
This commit is contained in:
@@ -59,7 +59,7 @@ def chat_templates(user_choice: str, tokenizer=None):
|
|||||||
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
||||||
]
|
]
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"No chat template found on tokenizer, falling back to {user_choice}"
|
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_choice in templates:
|
if user_choice in templates:
|
||||||
|
|||||||
@@ -83,6 +83,53 @@ def fixture_llama3_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatTemplates:
|
||||||
|
"""
|
||||||
|
Tests the chat_templates function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_invalid_chat_template(self):
|
||||||
|
with pytest.raises(ValueError) as exc:
|
||||||
|
chat_templates("invalid_template")
|
||||||
|
assert str(exc) == "Template 'invalid_template' not found."
|
||||||
|
|
||||||
|
def test_tokenizer_default_no_tokenizer(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chat_templates("tokenizer_default", tokenizer=None)
|
||||||
|
|
||||||
|
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chat_templates("tokenizer_default", tokenizer=llama3_tokenizer)
|
||||||
|
|
||||||
|
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
|
||||||
|
llama3_tokenizer.chat_template = "test_template"
|
||||||
|
chat_template_str = chat_templates(
|
||||||
|
"tokenizer_default", tokenizer=llama3_tokenizer
|
||||||
|
)
|
||||||
|
assert chat_template_str == "test_template"
|
||||||
|
|
||||||
|
def test_tokenizer_default_fallback_no_tokenizer(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
chat_templates("tokenizer_default_fallback_test", tokenizer=None)
|
||||||
|
|
||||||
|
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
|
||||||
|
self, llama3_tokenizer
|
||||||
|
):
|
||||||
|
chat_template_str = chat_templates(
|
||||||
|
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
||||||
|
)
|
||||||
|
assert chat_template_str == chat_templates("chatml")
|
||||||
|
|
||||||
|
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
|
||||||
|
self, llama3_tokenizer
|
||||||
|
):
|
||||||
|
llama3_tokenizer.chat_template = "test_template"
|
||||||
|
chat_template_str = chat_templates(
|
||||||
|
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
||||||
|
)
|
||||||
|
assert chat_template_str == "test_template"
|
||||||
|
|
||||||
|
|
||||||
class TestAssistantChatTemplateLlama3:
|
class TestAssistantChatTemplateLlama3:
|
||||||
"""
|
"""
|
||||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
|
|||||||
Reference in New Issue
Block a user