diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index d3e6e4d82..597eb4185 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -59,7 +59,7 @@ def chat_templates(user_choice: str, tokenizer=None):
             len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
         ]
         LOG.warning(
-            f"No chat template found on tokenizer, falling back to {user_choice}"
+            f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
         )
 
     if user_choice in templates:
diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py
index 7b58a1236..fe278fdec 100644
--- a/tests/prompt_strategies/test_chat_templates.py
+++ b/tests/prompt_strategies/test_chat_templates.py
@@ -83,6 +83,53 @@ def fixture_llama3_tokenizer():
     return tokenizer
 
 
+class TestChatTemplates:
+    """
+    Tests the chat_templates function.
+    """
+
+    def test_invalid_chat_template(self):
+        with pytest.raises(ValueError) as exc:
+            chat_templates("invalid_template")
+        assert str(exc.value) == "Template 'invalid_template' not found."
+
+    def test_tokenizer_default_no_tokenizer(self):
+        with pytest.raises(ValueError):
+            chat_templates("tokenizer_default", tokenizer=None)
+
+    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
+        with pytest.raises(ValueError):
+            chat_templates("tokenizer_default", tokenizer=llama3_tokenizer)
+
+    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
+        llama3_tokenizer.chat_template = "test_template"
+        chat_template_str = chat_templates(
+            "tokenizer_default", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == "test_template"
+
+    def test_tokenizer_default_fallback_no_tokenizer(self):
+        with pytest.raises(ValueError):
+            chat_templates("tokenizer_default_fallback_test", tokenizer=None)
+
+    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
+        self, llama3_tokenizer
+    ):
+        chat_template_str = chat_templates(
+            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == chat_templates("chatml")
+
+    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
+        self, llama3_tokenizer
+    ):
+        llama3_tokenizer.chat_template = "test_template"
+        chat_template_str = chat_templates(
+            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == "test_template"
+
+
 class TestAssistantChatTemplateLlama3:
     """
     Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.