Adds chat templates (#1022)
This commit is contained in:
@@ -589,6 +589,9 @@ datasets:
|
||||
# For `completion` datasets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
# Currently supports chatml and inst (mistral/mixtral)
|
||||
chat_template: chatml
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
|
||||
29
src/axolotl/utils/chat_templates.py
Normal file
29
src/axolotl/utils/chat_templates.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
This module provides functionality for selecting chat templates based on user choices.
|
||||
These templates are used for formatting messages in a conversation.
|
||||
"""
|
||||
|
||||
|
||||
def chat_templates(user_choice: str):
    """
    Finds the correct chat_template for the tokenizer_config.

    Args:
        user_choice (str): The user's choice of template.

    Returns:
        str: The chosen template string (a Jinja2 chat template).

    Raises:
        ValueError: If the user_choice is not found in the templates.
    """

    templates = {
        # Bare [INST]/[/INST] wrapper used by Mistral/Mixtral. Enforces strict
        # user/assistant turn alternation and appends eos_token after each
        # assistant reply.
        "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
        # ChatML: wraps each turn in <|im_start|>role ... <|im_end|> markers
        # and optionally appends an assistant header when add_generation_prompt
        # is set by the caller.
        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    }

    # EAFP: single dict lookup instead of a membership test followed by a
    # second lookup. The error now lists the valid choices so a user with a
    # typo in their config can fix it without reading the source.
    try:
        return templates[user_choice]
    except KeyError as exc:
        raise ValueError(
            f"Template '{user_choice}' not found. "
            f"Available templates: {', '.join(sorted(templates))}."
        ) from exc
|
||||
@@ -26,6 +26,7 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -186,6 +187,12 @@ def load_tokenizer(cfg):
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if cfg.chat_template:
|
||||
tokenizer.chat_template = chat_templates(cfg.chat_template)
|
||||
else:
|
||||
LOG.info(
|
||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user