From eb188acbd4b9c5cb8dd1948f78d79923da7d1dfa Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 31 Jul 2024 01:43:40 +0530 Subject: [PATCH] Add option `chat_template_jinja` to provide a jinja template --- .../prompt_strategies/chat_template.py | 11 +- .../prompt_strategies/dpo/chat_template.py | 19 ++- .../prompt_strategies/orpo/chat_template.py | 27 ++-- src/axolotl/utils/chat_templates.py | 69 ++++++++-- src/axolotl/utils/config/__init__.py | 1 + .../config/models/input/v0_4_1/__init__.py | 73 ++++++---- src/axolotl/utils/models.py | 7 +- .../test_chat_template_utils.py | 125 ++++++++++++++++++ .../prompt_strategies/test_chat_templates.py | 83 +++--------- 9 files changed, 284 insertions(+), 131 deletions(-) create mode 100644 tests/prompt_strategies/test_chat_template_utils.py diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 38460ebac..c79adbd5e 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config # Configure the logger logging.basicConfig(level=logging.DEBUG) @@ -338,13 +338,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): ds_cfg = ds_cfg or {} - chat_template = ds_cfg.get("chat_template", "chatml") - chat_template_str = chat_templates(chat_template, tokenizer=tokenizer) - LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---") + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) + LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") prompter_params = { "tokenizer": tokenizer, - "chat_template": chat_template_str, + "chat_template": chat_template_string, "message_field_role": ds_cfg.get("message_field_role", "from"), "message_field_content": ds_cfg.get("message_field_content", "value"), "message_field_training": ds_cfg.get("message_field_training", "training"), diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index e0e5eb129..60333b33b 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -2,15 +2,16 @@ DPO prompt strategies for using tokenizer chat templates. 
""" -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template def default( cfg, dataset_idx=0, **kwargs ): # pylint: disable=possibly-unused-variable,unused-argument ds_cfg = cfg["datasets"][dataset_idx] - chat_template_str = chat_templates(cfg.chat_template) - + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) field_messages = ds_cfg.get("field_messages", "messages") field_chosen = ds_cfg.get("field_chosen", "chosen") field_rejected = ds_cfg.get("field_rejected", "rejected") @@ -30,6 +31,12 @@ def default( role_map[source] = target def transform_fn(sample, tokenizer=None): + chat_template_string = get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) + messages = sample[field_messages] messages = [ { @@ -51,14 +58,14 @@ def default( result["prompt"] = tokenizer.apply_chat_template( messages, add_generation_prompt=True, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) result["chosen"] = tokenizer.apply_chat_template( [chosen], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) chosen_strip_index = result["chosen"].find(chosen["content"]) @@ -67,7 +74,7 @@ def default( result["rejected"] = tokenizer.apply_chat_template( [rejected], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) rejected_strip_index = result["rejected"].find(rejected["content"]) diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index f9e328ee6..e53a54748 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy from axolotl.prompters import Prompter -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config class Message(BaseModel): @@ -28,18 +28,13 @@ def load( """ chatml transforms for datasets with system, input, chosen, rejected """ - - chat_template = chat_templates("chatml") - if ds_cfg and "chat_template" in ds_cfg: - chat_template = ds_cfg["chat_template"] - try: - chat_template = chat_templates(chat_template, tokenizer=tokenizer) - except ValueError: - pass - tokenizer.chat_template = chat_template + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) + tokenizer.chat_template = chat_template_string return ORPOTokenizingStrategy( - ORPOPrompter(chat_template, tokenizer), + ORPOPrompter(chat_template_string, tokenizer), tokenizer, cfg.train_on_inputs, cfg.sequence_len, @@ -251,25 +246,27 @@ def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-a def transform_fn(sample, tokenizer=None): res = {} - chat_template_str = chat_templates(cfg.chat_template, tokenizer=tokenizer) + chat_template_string = get_chat_template_from_config( + cfg=cfg, tokenizer=tokenizer + ) res["prompt"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages], add_generation_prompt=True, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) prompt_str_len = len(res["prompt"]) 
res["chosen"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, )[prompt_str_len:] res["rejected"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, )[prompt_str_len:] diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 597eb4185..ac00a82ab 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -3,14 +3,34 @@ This module provides functionality for selecting chat templates based on user ch These templates are used for formatting messages in a conversation. """ import logging +from typing import TYPE_CHECKING, Any, Dict, Optional + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase LOG = logging.getLogger("axolotl.utils.chat_templates") +_JINJA_TEMPALTE_CHOICE = "jinja" _DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" _DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_" +_TEMPLATES = { + "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", + "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. 
+    "mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
+    "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+    "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
+    "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.'
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", +} -def chat_templates(user_choice: str, tokenizer=None): + +def get_chat_template( + user_choice: str, + jinja_template: Optional[str] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +): """ Finds the correct chat_template for the tokenizer_config. @@ -23,17 +43,12 @@ def chat_templates(user_choice: str, tokenizer=None): Raises: ValueError: If the user_choice is not found in the templates. """ - - templates = { - "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", - "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. 
-        "mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
-        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
-        "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
-        "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.'
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
-        "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
-        "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
-    }
+    if user_choice == _JINJA_TEMPLATE_CHOICE:
+        if not jinja_template:
+            raise ValueError(
+                f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}"
+            )
+        return jinja_template
 
     if user_choice == _DEFAULT_TEMPLATE_CHOICE:
         if not tokenizer:
@@ -62,7 +77,33 @@
             f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
) - if user_choice in templates: - return templates[user_choice] + if user_choice in _TEMPLATES: + return _TEMPLATES[user_choice] raise ValueError(f"Template '{user_choice}' not found.") + + +def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None): + if ds_cfg and ds_cfg.get("chat_template"): + chat_template_choice = ds_cfg.get("chat_template") or "chatml" + chat_template_jinja = ds_cfg.get("chat_template_jinja") + else: + chat_template_choice = cfg.get("chat_template") or "chatml" + chat_template_jinja = cfg.get("chat_template_jinja") + return chat_template_choice, chat_template_jinja + + +def get_chat_template_from_config( + cfg, + ds_cfg: Optional[Dict[str, Any]] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +) -> str: + ds_cfg = ds_cfg or {} + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) + return get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index ed165e89c..04caf0d5a 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -204,6 +204,7 @@ def normalize_cfg_datasets(cfg): f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template" ) cfg.datasets[idx].chat_template = cfg.chat_template + cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 343e65b55..26c02c7e2 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -28,6 +28,21 @@ LOG = logging.getLogger("axolotl.utils.config.models.input") SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} +class ChatTemplate(str, Enum): + """Chat templates configuration subset""" + + jinja = "jinja" # pylint: disable=invalid-name + alpaca = "alpaca" # pylint: disable=invalid-name + chatml = "chatml" # pylint: disable=invalid-name + inst = "inst" # pylint: disable=invalid-name + gemma = "gemma" # pylint: disable=invalid-name + cohere = "cohere" # pylint: disable=invalid-name + llama3 = "llama3" # pylint: disable=invalid-name + phi_3 = "phi_3" # pylint: disable=invalid-name + mistral = "mistral" # pylint: disable=invalid-name + tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name + + class DeprecatedParameters(BaseModel): """configurations that are deprecated""" @@ -111,12 +126,15 @@ class SFTDataset(BaseModel): type: Optional[Union[str, UserDefinedPrompterType]] = None shards: Optional[int] = None conversation: Optional[str] = None - chat_template: Optional[str] = None + chat_template: Union[ + ChatTemplate, + Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], + ] = ChatTemplate.chatml + chat_template_jinja: Optional[str] = None data_files: Optional[Union[str, List[str]]] = None name: Optional[str] = None ds_type: Optional[str] = None train_on_split: Optional[str] = None - field: Optional[str] = None field_human: Optional[str] = None field_model: Optional[str] = None @@ -127,12 +145,22 @@ class SFTDataset(BaseModel): message_field_training_detail: Optional[str] = None roles_to_train: Optional[List[str]] = None train_on_eos: Optional[str] = None - roles: 
Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None - trust_remote_code: Optional[bool] = False + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + return data + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" @@ -186,20 +214,6 @@ class RLType(str, Enum): simpo = "simpo" # pylint: disable=invalid-name -class ChatTemplate(str, Enum): - """Chat templates configuration subset""" - - alpaca = "alpaca" # pylint: disable=invalid-name - chatml = "chatml" # pylint: disable=invalid-name - inst = "inst" # pylint: disable=invalid-name - gemma = "gemma" # pylint: disable=invalid-name - cohere = "cohere" # pylint: disable=invalid-name - llama3 = "llama3" # pylint: disable=invalid-name - phi_3 = "phi_3" # pylint: disable=invalid-name - mistral = "mistral" # pylint: disable=invalid-name - tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name - - class LoftQConfig(BaseModel): """LoftQ configuration subset""" @@ -671,12 +685,11 @@ class AxolotlInputConfig( gpu_memory_limit: Optional[Union[int, str]] = None low_cpu_mem_usage: Optional[bool] = None - chat_template: Optional[ - Union[ - ChatTemplate, - Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], - ] - ] = None + chat_template: Union[ + ChatTemplate, + Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], + ] = ChatTemplate.chatml + chat_template_jinja: Optional[str] = None default_system_message: Optional[str] = None fix_untrained_tokens: Optional[bool] = None @@ -785,6 +798,18 @@ class AxolotlInputConfig( return data + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + return data + @model_validator(mode="before") @classmethod def check_sample_packing_wo_flash(cls, data): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ca6b1cb5e..52a85c2ac 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -45,7 +45,7 @@ from axolotl.monkeypatch.multipack import ( ) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import zero_only from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper @@ -285,7 +285,10 @@ def load_tokenizer(cfg): LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if cfg.chat_template: - chat_template_string = chat_templates(cfg.chat_template, tokenizer=tokenizer) + chat_template_string = get_chat_template_from_config( + cfg=cfg, + tokenizer=tokenizer, + ) if cfg.default_system_message and cfg.chat_template == "chatml": chat_template_string = chat_template_string.replace( "You are a helpful assistant.", cfg.default_system_message diff --git a/tests/prompt_strategies/test_chat_template_utils.py b/tests/prompt_strategies/test_chat_template_utils.py new file mode 100644 index 000000000..e220ed13c --- 
/dev/null
+++ b/tests/prompt_strategies/test_chat_template_utils.py
@@ -0,0 +1,125 @@
+"""
+Tests for utils in axolotl.utils.chat_templates
+"""
+import unittest
+
+import pytest
+from transformers import AutoTokenizer
+
+from axolotl.utils.chat_templates import (
+    _TEMPLATES,
+    extract_chat_template_args,
+    get_chat_template,
+)
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+
+    return tokenizer
+
+
+class TestGetChatTemplateUtils:
+    """
+    Tests the get_chat_template function.
+    """
+
+    def test_known_chat_template(self):
+        chat_template_str = get_chat_template("llama3")
+        assert chat_template_str == _TEMPLATES["llama3"]
+
+    def test_invalid_chat_template(self):
+        with pytest.raises(ValueError) as exc:
+            get_chat_template("invalid_template")
+        assert str(exc.value) == "Template 'invalid_template' not found."
+
+    def test_tokenizer_default_no_tokenizer(self):
+        with pytest.raises(ValueError):
+            get_chat_template("tokenizer_default", tokenizer=None)
+
+    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
+        with pytest.raises(ValueError):
+            get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
+
+    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
+        llama3_tokenizer.chat_template = "test_template"
+        chat_template_str = get_chat_template(
+            "tokenizer_default", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == "test_template"
+
+    def test_tokenizer_default_fallback_no_tokenizer(self):
+        with pytest.raises(ValueError):
+            get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
+
+    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
+        self, llama3_tokenizer
+    ):
+        chat_template_str = get_chat_template(
+            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == get_chat_template("chatml")
+
+    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
+        self, llama3_tokenizer
+    ):
+        llama3_tokenizer.chat_template = "test_template"
+        chat_template_str = get_chat_template(
+            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == "test_template"
+
+    def test_jinja_template_mode(self):
+        jinja_template = "example_jinja_template"
+        chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
+        assert chat_template_str == jinja_template
+
+    def test_jinja_template_mode_no_jinja_template(self):
+        with pytest.raises(ValueError):
+            get_chat_template("jinja", jinja_template=None)
+
+    def test_extract_chat_template_args(self):
+        # No ds_cfg
+        chat_template_choice, chat_template_jinja = extract_chat_template_args(
+            cfg={"chat_template": "chatml"},
+        )
+        assert chat_template_choice == "chatml"
+        assert chat_template_jinja is None
+
+        # ds_cfg provided
+        chat_template_choice, chat_template_jinja = extract_chat_template_args(
+            cfg={
+                "chat_template": "jinja",
+                "chat_template_jinja": "global_jinja_template",
+            },
+            ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
+        )
+        assert chat_template_choice == "llama3"
+        assert chat_template_jinja is None
+
+        # ds_cfg provided with jinja template
+        chat_template_choice, chat_template_jinja = extract_chat_template_args(
+            cfg={"chat_template": "chatml", "chat_template_jinja": None},
+            ds_cfg={
+                "chat_template": "jinja",
+                "chat_template_jinja": "ds_jinja_template",
+            },
+        )
+        assert chat_template_choice == "jinja"
+        assert
chat_template_jinja == "ds_jinja_template" + + # ds_cfg provided with no chat_template + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={ + "chat_template": "jinja", + "chat_template_jinja": "global_jinja_template", + }, + ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"}, + ) + assert chat_template_choice == "jinja" + assert chat_template_jinja == "global_jinja_template" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 97b8792fd..deae5e7eb 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -15,7 +15,7 @@ from axolotl.prompt_strategies.chat_template import ( load, ) from axolotl.prompters import IGNORE_TOKEN_ID -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.dict import DictDefault logging.basicConfig(level=logging.DEBUG) @@ -80,53 +80,6 @@ def fixture_llama3_tokenizer(): return tokenizer -class TestChatTemplates: - """ - Tests the chat_templates function. - """ - - def test_invalid_chat_template(self): - with pytest.raises(ValueError) as exc: - chat_templates("invalid_template") - assert str(exc) == "Template 'invalid_template' not found." - - def test_tokenizer_default_no_tokenizer(self): - with pytest.raises(ValueError): - chat_templates("tokenizer_default", tokenizer=None) - - def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer): - with pytest.raises(ValueError): - chat_templates("tokenizer_default", tokenizer=llama3_tokenizer) - - def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer): - llama3_tokenizer.chat_template = "test_template" - chat_template_str = chat_templates( - "tokenizer_default", tokenizer=llama3_tokenizer - ) - assert chat_template_str == "test_template" - - def test_tokenizer_default_fallback_no_tokenizer(self): - with pytest.raises(ValueError): - chat_templates("tokenizer_default_fallback_test", tokenizer=None) - - def test_tokenizer_default_fallback_no_chat_template_on_tokenizer( - self, llama3_tokenizer - ): - chat_template_str = chat_templates( - "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer - ) - assert chat_template_str == chat_templates("chatml") - - def test_tokenizer_default_fallback_with_chat_template_on_tokenizer( - self, llama3_tokenizer - ): - llama3_tokenizer.chat_template = "test_template" - chat_template_str = chat_templates( - "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer - ) - assert chat_template_str == "test_template" - - class TestChatTemplateConfigurations: """ Test class for various configurations of ChatTemplateStrategy. 
@@ -143,7 +96,7 @@ class TestChatTemplateConfigurations: def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=True") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -186,7 +139,7 @@ class TestChatTemplateConfigurations: def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=False") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -227,7 +180,7 @@ class TestChatTemplateConfigurations: def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with assistant only") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -253,7 +206,7 @@ class TestChatTemplateConfigurations: def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with all roles") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=True, sequence_len=512, @@ -284,7 +237,7 @@ class TestChatTemplateConfigurations: def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with empty roles_to_train") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -303,7 +256,7 @@ class TestChatTemplateConfigurations: def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='all'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -328,7 +281,7 @@ class TestChatTemplateConfigurations: def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='turn'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -376,7 +329,7 @@ class TestChatTemplateConfigurations: def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='last'") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -407,7 +360,7 @@ class TestChatTemplateConfigurations: def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='none'") strategy = ChatTemplateStrategy( - 
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, sequence_len=512, @@ -433,7 +386,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with drop_system_message=True") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), drop_system_message=True + llama3_tokenizer, get_chat_template("llama3"), drop_system_message=True ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -459,7 +412,7 @@ class TestChatTemplateConfigurations: } strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_templates("llama3"), roles=custom_roles + llama3_tokenizer, get_chat_template("llama3"), roles=custom_roles ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -511,7 +464,7 @@ class TestChatTemplateConfigurations: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + get_chat_template("llama3"), message_field_training="train", message_field_training_detail="train_detail", ), @@ -775,7 +728,7 @@ class TestAssistantChatTemplateLlama3: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + get_chat_template("llama3"), message_field_role="role", message_field_content="content", roles={ @@ -816,7 +769,7 @@ class TestAssistantChatTemplateLlama3: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_templates("llama3"), + get_chat_template("llama3"), message_field_role="role", message_field_content="content", message_field_training="training", @@ -873,7 +826,7 @@ class TestSharegptChatTemplateLlama3: def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -923,7 +876,7 @@ class TestSharegptChatTemplateLlama3: def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 human prompts") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none", @@ -973,7 +926,7 @@ class TestSharegptChatTemplateLlama3: def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts") strategy = ChatTemplateStrategy( - ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")), + ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")), tokenizer=llama3_tokenizer, train_on_inputs=False, train_on_eos="none",
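
Usage sketch: a minimal example of how the new `chat_template: jinja` /
`chat_template_jinja` pair introduced by this patch is resolved, assuming
axolotl with this change installed. The inline Jinja string and the
per-dataset override below are hypothetical stand-ins, not templates shipped
with the library.

    # Resolving a chat template through the helpers added in this patch.
    from axolotl.utils.chat_templates import get_chat_template_from_config

    cfg = {
        "chat_template": "jinja",
        # Hypothetical template; any valid Jinja chat template string works.
        "chat_template_jinja": (
            "{% for message in messages %}"
            "{{ message['role'] }}: {{ message['content'] }}\n"
            "{% endfor %}"
        ),
    }

    # A dataset entry without its own `chat_template` falls back to the
    # top-level choice, per `extract_chat_template_args`.
    ds_cfg = {"chat_template": None, "chat_template_jinja": None}

    template = get_chat_template_from_config(cfg=cfg, ds_cfg=ds_cfg)
    assert template == cfg["chat_template_jinja"]

In a YAML config this corresponds to setting `chat_template: jinja` together
with a `chat_template_jinja` string; the `check_chat_template_config`
validators above reject `chat_template: jinja` without it.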