Add option chat_template_jinja to provide a jinja template

This commit is contained in:
Chirag Jain
2024-07-31 01:43:40 +05:30
parent 34ea51dcf3
commit eb188acbd4
9 changed files with 284 additions and 131 deletions

View File

@@ -0,0 +1,125 @@
"""
Tests for utils in axolotl.utils.chat_templates
"""
import unittest
import pytest
from transformers import AutoTokenizer
from axolotl.utils.chat_templates import (
_TEMPLATES,
extract_chat_template_args,
get_chat_template,
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
return tokenizer
class TestGetChatTemplateUtils:
"""
Tests the get_chat_template function.
"""
def test_known_chat_template(self):
chat_template_str = get_chat_template("llama3")
assert chat_template_str == _TEMPLATES["llama3"]
def test_invalid_chat_template(self):
with pytest.raises(ValueError) as exc:
get_chat_template("invalid_template")
assert str(exc) == "Template 'invalid_template' not found."
def test_tokenizer_default_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=None)
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_tokenizer_default_fallback_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
self, llama3_tokenizer
):
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == get_chat_template("chatml")
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
self, llama3_tokenizer
):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_jinja_template_mode(self):
jinja_template = "example_jinja_template"
chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
assert chat_template_str == jinja_template
def test_jinja_template_mode_no_jinja_template(self):
with pytest.raises(ValueError):
get_chat_template("jinja", jinja_template=None)
def test_extract_chat_template_args(self):
# No ds_cfg
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml"},
)
assert chat_template_choice == "chatml"
assert chat_template_jinja is None
# ds_cfg provided
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
)
assert chat_template_choice == "llama3"
assert chat_template_jinja is None
# ds_cfg provided with jinja template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml", "chat_template_jinja": None},
ds_cfg={
"chat_template": "jinja",
"chat_template_jinja": "ds_jinja_template",
},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "ds_jinja_template"
# ds_cfg provided with no chat_template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "global_jinja_template"
if __name__ == "__main__":
unittest.main()

View File

@@ -15,7 +15,7 @@ from axolotl.prompt_strategies.chat_template import (
load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
@@ -80,53 +80,6 @@ def fixture_llama3_tokenizer():
return tokenizer
class TestChatTemplates:
"""
Tests the chat_templates function.
"""
def test_invalid_chat_template(self):
with pytest.raises(ValueError) as exc:
chat_templates("invalid_template")
assert str(exc) == "Template 'invalid_template' not found."
def test_tokenizer_default_no_tokenizer(self):
with pytest.raises(ValueError):
chat_templates("tokenizer_default", tokenizer=None)
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
with pytest.raises(ValueError):
chat_templates("tokenizer_default", tokenizer=llama3_tokenizer)
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = chat_templates(
"tokenizer_default", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_tokenizer_default_fallback_no_tokenizer(self):
with pytest.raises(ValueError):
chat_templates("tokenizer_default_fallback_test", tokenizer=None)
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
self, llama3_tokenizer
):
chat_template_str = chat_templates(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == chat_templates("chatml")
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
self, llama3_tokenizer
):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = chat_templates(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
class TestChatTemplateConfigurations:
"""
Test class for various configurations of ChatTemplateStrategy.
@@ -143,7 +96,7 @@ class TestChatTemplateConfigurations:
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
@@ -186,7 +139,7 @@ class TestChatTemplateConfigurations:
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -227,7 +180,7 @@ class TestChatTemplateConfigurations:
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -253,7 +206,7 @@ class TestChatTemplateConfigurations:
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
sequence_len=512,
@@ -284,7 +237,7 @@ class TestChatTemplateConfigurations:
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -303,7 +256,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -328,7 +281,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -376,7 +329,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -407,7 +360,7 @@ class TestChatTemplateConfigurations:
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
sequence_len=512,
@@ -433,7 +386,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with drop_system_message=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
llama3_tokenizer, get_chat_template("llama3"), drop_system_message=True
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -459,7 +412,7 @@ class TestChatTemplateConfigurations:
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
llama3_tokenizer, get_chat_template("llama3"), roles=custom_roles
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -511,7 +464,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
get_chat_template("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),
@@ -775,7 +728,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
roles={
@@ -816,7 +769,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_templates("llama3"),
get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
@@ -873,7 +826,7 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
@@ -923,7 +876,7 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",
@@ -973,7 +926,7 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
train_on_eos="none",