* Allow using tokenizer's default chat template with fallbacks

Summary of changes:
1. Adds `tokenizer_default` as an option for `chat_template` in the `chat_template` prompt strategy, allowing use of the chat template from the tokenizer's config.json
2. Allows falling back to chat templates available in axolotl if the tokenizer does not have a chat template
3. Adds a mistral chat template which supports a system message - taken from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja

---

Why? Many popular models are not trained with the chatml format, so for the model to correctly learn chatml we have to turn on train_on_inputs, which requires more compute and time. If we can use the chat template the model has already learned, we only need to learn the output tokens.

---

Todo:
- Write tests

* Add tests
* Fix lint and bug post merge from main
* Add option `chat_template_jinja` to provide a jinja template
* remove custom mistral template
* Address review comments and add docs
* Update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* fix: set default to tokenizer template
* Merge branch 'main' into cj_tokenizer_default_prompt_template
* chore: remove redundant function
* fix: re-arrange enum declaration position
* fix: refactor artifact left from main merge
* feat(doc): updated config with chat template options and clarified examples
* chore: clarify doc
* chore: added example for non-default template
* chore: refactor
* fix: test
* fix: config being dropped and unittest to catch that
* chore: lint
* chore: skip duplicate
* fix: rename var after merge
* feat: add test for levy's dpo case
* fix: remove default setting on edge case where chat template overriden in dataset section
* feat: handle sharegpt deprecation better in docs
* feat: add example using fallback
* feat: handles chat_template requiring specific user/assistant order
* fix: update test based on new defaults
* fix: imported name incorrectly updated on merge
* chore: lint
* fix: update dummy message to prevent potential overlap with real content
* fix(doc): formatting
* fix: update bradleyterry to use new chat_template

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
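For context, a minimal sketch of how the new `chat_template` choices are expected to resolve, based on the behavior pinned down by the tests below. In practice these values are set through the `chat_template` / `chat_template_jinja` config options described above; the direct calls and the fallback choice here are illustrative only, not part of the test file.

```python
# Illustrative sketch only, mirroring what the tests below exercise.
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

# "tokenizer_default" uses tokenizer.chat_template and raises if the tokenizer has none.
# "tokenizer_default_fallback_chatml" falls back to axolotl's chatml template instead;
# the base Llama-3 tokenizer ships without a chat_template, so the fallback applies here.
template = get_chat_template("tokenizer_default_fallback_chatml", tokenizer=tokenizer)

# "jinja" takes a template string directly via the companion jinja_template argument.
custom = get_chat_template("jinja", jinja_template="{{ messages }}")
```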
"""
|
|
Tests for utils in axolotl.utils.chat_templates
|
|
"""
|
|
import unittest
|
|
|
|
import pytest
|
|
from transformers import AutoTokenizer
|
|
|
|
from axolotl.utils.chat_templates import (
|
|
_CHAT_TEMPLATES,
|
|
extract_chat_template_args,
|
|
get_chat_template,
|
|
)
|
|
|
|
|
|
@pytest.fixture(name="llama3_tokenizer")
|
|
def fixture_llama3_tokenizer():
|
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
|
|
|
return tokenizer
|
|
|
|
|
|
class TestGetChatTemplateUtils:
|
|
"""
|
|
Tests the get_chat_template function.
|
|
"""
|
|
|
|
def test_known_chat_template(self):
|
|
chat_template_str = get_chat_template("llama3")
|
|
assert chat_template_str == _CHAT_TEMPLATES["llama3"]
|
|
|
|

    def test_invalid_chat_template(self):
        with pytest.raises(ValueError) as exc:
            get_chat_template("invalid_template")
        assert str(exc.value) == "Template 'invalid_template' not found."

    def test_tokenizer_default_no_tokenizer(self):
        with pytest.raises(ValueError):
            get_chat_template("tokenizer_default", tokenizer=None)

    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
        with pytest.raises(ValueError):
            get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)

    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = get_chat_template(
            "tokenizer_default", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"

    def test_tokenizer_default_fallback_no_tokenizer(self):
        with pytest.raises(ValueError):
            get_chat_template("tokenizer_default_fallback_test", tokenizer=None)

    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        chat_template_str = get_chat_template(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == get_chat_template("chatml")

    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
        self, llama3_tokenizer
    ):
        llama3_tokenizer.chat_template = "test_template"
        chat_template_str = get_chat_template(
            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
        )
        assert chat_template_str == "test_template"

    def test_jinja_template_mode(self):
        jinja_template = "example_jinja_template"
        chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
        assert chat_template_str == jinja_template

    def test_jinja_template_mode_no_jinja_template(self):
        with pytest.raises(ValueError):
            get_chat_template("jinja", jinja_template=None)

    def test_extract_chat_template_args(self):
        # No ds_cfg
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={"chat_template": "chatml"},
        )
        assert chat_template_choice == "chatml"
        assert chat_template_jinja is None

        # ds_cfg provided
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={
                "chat_template": "jinja",
                "chat_template_jinja": "global_jinja_template",
            },
            ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
        )
        assert chat_template_choice == "llama3"
        assert chat_template_jinja is None

        # ds_cfg provided with jinja template
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={"chat_template": "chatml", "chat_template_jinja": None},
            ds_cfg={
                "chat_template": "jinja",
                "chat_template_jinja": "ds_jinja_template",
            },
        )
        assert chat_template_choice == "jinja"
        assert chat_template_jinja == "ds_jinja_template"

        # ds_cfg provided with no chat_template
        chat_template_choice, chat_template_jinja = extract_chat_template_args(
            cfg={
                "chat_template": "jinja",
                "chat_template_jinja": "global_jinja_template",
            },
            ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
        )
        assert chat_template_choice == "jinja"
        assert chat_template_jinja == "global_jinja_template"


if __name__ == "__main__":
    unittest.main()