126 lines
4.4 KiB
Python
126 lines
4.4 KiB
Python
"""
|
|
Tests for utils in axolotl.utils.chat_templates
|
|
"""
|
|
import unittest
|
|
|
|
import pytest
|
|
from transformers import AutoTokenizer
|
|
|
|
from axolotl.utils.chat_templates import (
|
|
_CHAT_TEMPLATES,
|
|
extract_chat_template_args,
|
|
get_chat_template,
|
|
)
|
|
|
|
|
|
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
    """Provide a freshly loaded tokenizer from NousResearch/Meta-Llama-3-8B.

    The fixture uses pytest's default (function) scope on purpose: several
    tests assign ``chat_template`` on the returned tokenizer, so one instance
    must never be shared between tests.
    """
    return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
|
|
|
|
|
class TestGetChatTemplateUtils:
    """
    Tests for ``get_chat_template`` and ``extract_chat_template_args``.
    """

    def test_known_chat_template(self):
        """A registered template name resolves to its ``_CHAT_TEMPLATES`` entry."""
        chat_template_str = get_chat_template("llama3")
        assert chat_template_str == _CHAT_TEMPLATES["llama3"]

    def test_invalid_chat_template(self):
        """An unknown template name raises ``ValueError`` naming the template."""
        with pytest.raises(ValueError) as exc:
            get_chat_template("invalid_template")
        # The message check must live outside the ``with`` block: the call above
        # raises, so any statement after it inside the block would never run.
        # ``exc`` is a pytest ``ExceptionInfo``; the raised exception is
        # ``exc.value``.
        assert str(exc.value) == "Template 'invalid_template' not found."

    def test_tokenizer_default_no_tokenizer(self):
        """``tokenizer_default`` without a tokenizer raises ``ValueError``."""
        with pytest.raises(ValueError):
            get_chat_template("tokenizer_default", tokenizer=None)

    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
        """``tokenizer_default`` raises when the tokenizer lacks a chat template."""
        # NOTE(review): relies on the base (non-instruct) Llama 3 tokenizer from
        # the fixture shipping without a chat_template — confirm if the fixture
        # model changes.
        with pytest.raises(ValueError):
            get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)

    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
        """``tokenizer_default`` returns the tokenizer's own chat template."""
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = get_chat_template(
            "tokenizer_default", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"

    def test_tokenizer_default_fallback_no_tokenizer(self):
        """A fallback to ``test`` with no tokenizer raises ``ValueError``."""
        # NOTE(review): "test" is presumably not a registered template, so this
        # exercises the unknown-fallback path rather than the missing-tokenizer
        # path — confirm this is the intended coverage.
        with pytest.raises(ValueError):
            get_chat_template("tokenizer_default_fallback_test", tokenizer=None)

    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        """Without a tokenizer chat template, the named fallback is returned."""
        chat_template_str = get_chat_template(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == get_chat_template("chatml")

    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        """A tokenizer chat template takes precedence over the named fallback."""
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = get_chat_template(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"

    def test_jinja_template_mode(self):
        """``jinja`` mode returns the supplied template string unchanged."""
        jinja_template = "example_jinja_template"
        chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
        assert chat_template_str == jinja_template

    def test_jinja_template_mode_no_jinja_template(self):
        """``jinja`` mode without a template string raises ``ValueError``."""
        with pytest.raises(ValueError):
            get_chat_template("jinja", jinja_template=None)

    def test_extract_chat_template_args(self):
        """
        Dataset-level (``ds_cfg``) settings override the global ``cfg`` only
        when the dataset sets ``chat_template``.
        """
        # No ds_cfg: the global choice is returned, and a cfg without a
        # ``chat_template_jinja`` key yields None for the jinja template.
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={"chat_template": "chatml"},
        )
        assert chat_template_choice == "chatml"
        assert chat_template_jinja is None

        # ds_cfg sets chat_template: both its choice and its (None) jinja
        # template win over the global jinja configuration.
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={
                "chat_template": "jinja",
                "chat_template_jinja": "global_jinja_template",
            },
            ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
        )
        assert chat_template_choice == "llama3"
        assert chat_template_jinja is None

        # ds_cfg selects jinja mode with its own template string.
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={"chat_template": "chatml", "chat_template_jinja": None},
            ds_cfg={
                "chat_template": "jinja",
                "chat_template_jinja": "ds_jinja_template",
            },
        )
        assert chat_template_choice == "jinja"
        assert chat_template_jinja == "ds_jinja_template"

        # ds_cfg leaves chat_template unset: the global cfg pair is used and
        # ds_cfg's own chat_template_jinja is ignored.
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={
                "chat_template": "jinja",
                "chat_template_jinja": "global_jinja_template",
            },
            ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
        )
        assert chat_template_choice == "jinja"
        assert chat_template_jinja == "global_jinja_template"
|
|
|
|
|
|
if __name__ == "__main__":
    # The tests above are pytest-style (a plain class plus pytest fixtures),
    # which ``unittest.main()`` does not discover — it would run zero tests.
    # Delegate to pytest and propagate its exit status to the shell.
    raise SystemExit(pytest.main([__file__]))
|