Add option chat_template_jinja to provide a jinja template
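For context, a minimal usage sketch (hedged: plain dicts stand in for axolotl's config objects, and the helpers are the ones introduced below). It shows the YAML pair `chat_template: jinja` / `chat_template_jinja: <raw template>` resolving to the raw template string:

from axolotl.utils.chat_templates import get_chat_template_from_config

# Equivalent of setting `chat_template: jinja` plus `chat_template_jinja`
# in the YAML config; the raw template is returned verbatim.
cfg = {
    "chat_template": "jinja",
    "chat_template_jinja": (
        "{% for message in messages %}"
        "{{ message['role'] + ': ' + message['content'] + '\n' }}"
        "{% endfor %}"
    ),
}

# No tokenizer is needed for the "jinja" choice.
template = get_chat_template_from_config(cfg=cfg)
print(template)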
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional

 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config

 # Configure the logger
 logging.basicConfig(level=logging.DEBUG)
@@ -338,13 +338,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     ds_cfg = ds_cfg or {}
-    chat_template = ds_cfg.get("chat_template", "chatml")
-    chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
-    LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")
+    chat_template_string = get_chat_template_from_config(
+        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+    )
+    LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

     prompter_params = {
         "tokenizer": tokenizer,
-        "chat_template": chat_template_str,
+        "chat_template": chat_template_string,
         "message_field_role": ds_cfg.get("message_field_role", "from"),
         "message_field_content": ds_cfg.get("message_field_content", "value"),
         "message_field_training": ds_cfg.get("message_field_training", "training"),
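A hedged sketch of the per-dataset override this new `load` path enables (plain dicts stand in for the config objects):

from axolotl.utils.chat_templates import get_chat_template_from_config

# The global config asks for chatml, but this dataset supplies its own raw
# jinja template; the dataset-level setting wins during extraction.
cfg = {"chat_template": "chatml"}
ds_cfg = {
    "chat_template": "jinja",
    "chat_template_jinja": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
}

template = get_chat_template_from_config(cfg=cfg, ds_cfg=ds_cfg)
assert template == ds_cfg["chat_template_jinja"]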
@@ -2,15 +2,16 @@
 DPO prompt strategies for using tokenizer chat templates.
 """

-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template


 def default(
     cfg, dataset_idx=0, **kwargs
 ): # pylint: disable=possibly-unused-variable,unused-argument
     ds_cfg = cfg["datasets"][dataset_idx]
-    chat_template_str = chat_templates(cfg.chat_template)
+    chat_template_choice, chat_template_jinja = extract_chat_template_args(
+        cfg=cfg, ds_cfg=ds_cfg
+    )
     field_messages = ds_cfg.get("field_messages", "messages")
     field_chosen = ds_cfg.get("field_chosen", "chosen")
     field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -30,6 +31,12 @@ def default(
         role_map[source] = target

     def transform_fn(sample, tokenizer=None):
+        chat_template_string = get_chat_template(
+            user_choice=chat_template_choice,
+            jinja_template=chat_template_jinja,
+            tokenizer=tokenizer,
+        )
+
         messages = sample[field_messages]
         messages = [
             {
@@ -51,14 +58,14 @@ def default(
         result["prompt"] = tokenizer.apply_chat_template(
             messages,
             add_generation_prompt=True,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
             tokenize=False,
         )

         result["chosen"] = tokenizer.apply_chat_template(
             [chosen],
             add_generation_prompt=False,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
             tokenize=False,
         )
         chosen_strip_index = result["chosen"].find(chosen["content"])
@@ -67,7 +74,7 @@ def default(
         result["rejected"] = tokenizer.apply_chat_template(
             [rejected],
             add_generation_prompt=False,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
             tokenize=False,
         )
         rejected_strip_index = result["rejected"].find(rejected["content"])
@@ -5,7 +5,7 @@ from pydantic import BaseModel

 from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
 from axolotl.prompters import Prompter
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config


 class Message(BaseModel):
@@ -28,18 +28,13 @@ def load(
     """
     chatml transforms for datasets with system, input, chosen, rejected
     """
-    chat_template = chat_templates("chatml")
-    if ds_cfg and "chat_template" in ds_cfg:
-        chat_template = ds_cfg["chat_template"]
-        try:
-            chat_template = chat_templates(chat_template, tokenizer=tokenizer)
-        except ValueError:
-            pass
-    tokenizer.chat_template = chat_template
+    chat_template_string = get_chat_template_from_config(
+        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+    )
+    tokenizer.chat_template = chat_template_string

     return ORPOTokenizingStrategy(
-        ORPOPrompter(chat_template, tokenizer),
+        ORPOPrompter(chat_template_string, tokenizer),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -251,25 +246,27 @@ def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-a
     def transform_fn(sample, tokenizer=None):
         res = {}

-        chat_template_str = chat_templates(cfg.chat_template, tokenizer=tokenizer)
+        chat_template_string = get_chat_template_from_config(
+            cfg=cfg, tokenizer=tokenizer
+        )

         res["prompt"] = tokenizer.apply_chat_template(
             [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
             add_generation_prompt=True,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
             tokenize=False,
         )
         prompt_str_len = len(res["prompt"])
         res["chosen"] = tokenizer.apply_chat_template(
             [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
             add_generation_prompt=False,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
             tokenize=False,
         )[prompt_str_len:]
         res["rejected"] = tokenizer.apply_chat_template(
             [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
             add_generation_prompt=False,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
             tokenize=False,
         )[prompt_str_len:]
@@ -3,14 +3,34 @@ This module provides functionality for selecting chat templates based on user ch
 These templates are used for formatting messages in a conversation.
 """
+import logging
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
+LOG = logging.getLogger("axolotl.utils.chat_templates")
+
+_JINJA_TEMPLATE_CHOICE = "jinja"
+_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
+_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
+
+_TEMPLATES = {
+    "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
+    "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
+    "mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
+    "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+    "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
+    "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
+    "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
+    "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+}

-def chat_templates(user_choice: str, tokenizer=None):
+def get_chat_template(
+    user_choice: str,
+    jinja_template: Optional[str] = None,
+    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+):
     """
     Finds the correct chat_template for the tokenizer_config.
@@ -23,17 +43,12 @@ def chat_templates(user_choice: str, tokenizer=None):
     Raises:
         ValueError: If the user_choice is not found in the templates.
     """
-    templates = {
-        "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
-        "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
-        "mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
-        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
-        "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
-        "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
-        "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
-        "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
-    }
+    if user_choice == _JINJA_TEMPLATE_CHOICE:
+        if not jinja_template:
+            raise ValueError(
+                f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}"
+            )
+        return jinja_template

     if user_choice == _DEFAULT_TEMPLATE_CHOICE:
         if not tokenizer:
@@ -62,7 +77,33 @@ def chat_templates(user_choice: str, tokenizer=None):
             f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
         )

-    if user_choice in templates:
-        return templates[user_choice]
+    if user_choice in _TEMPLATES:
+        return _TEMPLATES[user_choice]

     raise ValueError(f"Template '{user_choice}' not found.")


+def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    if ds_cfg and ds_cfg.get("chat_template"):
+        chat_template_choice = ds_cfg.get("chat_template") or "chatml"
+        chat_template_jinja = ds_cfg.get("chat_template_jinja")
+    else:
+        chat_template_choice = cfg.get("chat_template") or "chatml"
+        chat_template_jinja = cfg.get("chat_template_jinja")
+    return chat_template_choice, chat_template_jinja
+
+
+def get_chat_template_from_config(
+    cfg,
+    ds_cfg: Optional[Dict[str, Any]] = None,
+    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+) -> str:
+    ds_cfg = ds_cfg or {}
+    chat_template_choice, chat_template_jinja = extract_chat_template_args(
+        cfg=cfg, ds_cfg=ds_cfg
+    )
+    return get_chat_template(
+        user_choice=chat_template_choice,
+        jinja_template=chat_template_jinja,
+        tokenizer=tokenizer,
+    )
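The error paths of the new helper, as a hedged mini-example (it mirrors the ValueErrors raised above):

from axolotl.utils.chat_templates import get_chat_template

# Choosing "jinja" without supplying the raw template fails fast.
try:
    get_chat_template("jinja", jinja_template=None)
except ValueError as err:
    print(err)  # `jinja_template` cannot be None when `chat_template` choice is jinja

# Unknown template names still fail loudly.
try:
    get_chat_template("no_such_template")
except ValueError as err:
    print(err)  # Template 'no_such_template' not found.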
@@ -204,6 +204,7 @@ def normalize_cfg_datasets(cfg):
                     f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
                 )
                 cfg.datasets[idx].chat_template = cfg.chat_template
+                cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja


 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
@@ -28,6 +28,21 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
 SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}


+class ChatTemplate(str, Enum):
+    """Chat templates configuration subset"""
+
+    jinja = "jinja"  # pylint: disable=invalid-name
+    alpaca = "alpaca"  # pylint: disable=invalid-name
+    chatml = "chatml"  # pylint: disable=invalid-name
+    inst = "inst"  # pylint: disable=invalid-name
+    gemma = "gemma"  # pylint: disable=invalid-name
+    cohere = "cohere"  # pylint: disable=invalid-name
+    llama3 = "llama3"  # pylint: disable=invalid-name
+    phi_3 = "phi_3"  # pylint: disable=invalid-name
+    mistral = "mistral"  # pylint: disable=invalid-name
+    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
+
+
 class DeprecatedParameters(BaseModel):
     """configurations that are deprecated"""
@@ -111,12 +126,15 @@ class SFTDataset(BaseModel):
     type: Optional[Union[str, UserDefinedPrompterType]] = None
     shards: Optional[int] = None
     conversation: Optional[str] = None
-    chat_template: Optional[str] = None
+    chat_template: Union[
+        ChatTemplate,
+        Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
+    ] = ChatTemplate.chatml
+    chat_template_jinja: Optional[str] = None
     data_files: Optional[Union[str, List[str]]] = None
     name: Optional[str] = None
     ds_type: Optional[str] = None
     train_on_split: Optional[str] = None

     field: Optional[str] = None
     field_human: Optional[str] = None
     field_model: Optional[str] = None
@@ -127,12 +145,22 @@ class SFTDataset(BaseModel):
     message_field_training_detail: Optional[str] = None
     roles_to_train: Optional[List[str]] = None
     train_on_eos: Optional[str] = None

     roles: Optional[Dict[str, List[str]]] = None
     drop_system_message: Optional[bool] = None

     trust_remote_code: Optional[bool] = False

+    @model_validator(mode="before")
+    @classmethod
+    def check_chat_template_config(cls, data):
+        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+            "chat_template_jinja"
+        ):
+            raise ValueError(
+                "chat_template_jinja is required when chat_template is set to jinja"
+            )
+
+        return data
+

 class UserDefinedDPOType(BaseModel):
     """User defined typing for DPO"""
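A hedged sketch of the new validator in action. The import path is an assumption, and because the validator runs with mode="before", a bare dict is enough to trigger it ahead of any field validation:

import pytest
from pydantic import ValidationError

# Import path is assumed; SFTDataset is the model patched above.
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset

# `chat_template: jinja` without `chat_template_jinja` is rejected.
with pytest.raises(ValidationError, match="chat_template_jinja is required"):
    SFTDataset.model_validate({"chat_template": "jinja"})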
@@ -186,20 +214,6 @@ class RLType(str, Enum):
     simpo = "simpo"  # pylint: disable=invalid-name


-class ChatTemplate(str, Enum):
-    """Chat templates configuration subset"""
-
-    alpaca = "alpaca"  # pylint: disable=invalid-name
-    chatml = "chatml"  # pylint: disable=invalid-name
-    inst = "inst"  # pylint: disable=invalid-name
-    gemma = "gemma"  # pylint: disable=invalid-name
-    cohere = "cohere"  # pylint: disable=invalid-name
-    llama3 = "llama3"  # pylint: disable=invalid-name
-    phi_3 = "phi_3"  # pylint: disable=invalid-name
-    mistral = "mistral"  # pylint: disable=invalid-name
-    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
-
-
 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""
@@ -671,12 +685,11 @@ class AxolotlInputConfig(
     gpu_memory_limit: Optional[Union[int, str]] = None
     low_cpu_mem_usage: Optional[bool] = None

-    chat_template: Optional[
-        Union[
-            ChatTemplate,
-            Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
-        ]
-    ] = None
+    chat_template: Union[
+        ChatTemplate,
+        Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
+    ] = ChatTemplate.chatml
+    chat_template_jinja: Optional[str] = None
     default_system_message: Optional[str] = None

     fix_untrained_tokens: Optional[bool] = None
@@ -785,6 +798,18 @@ class AxolotlInputConfig(

         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_chat_template_config(cls, data):
+        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+            "chat_template_jinja"
+        ):
+            raise ValueError(
+                "chat_template_jinja is required when chat_template is set to jinja"
+            )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_wo_flash(cls, data):
@@ -45,7 +45,7 @@ from axolotl.monkeypatch.multipack import (
 )
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
@@ -285,7 +285,10 @@ def load_tokenizer(cfg):
     LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

     if cfg.chat_template:
-        chat_template_string = chat_templates(cfg.chat_template, tokenizer=tokenizer)
+        chat_template_string = get_chat_template_from_config(
+            cfg=cfg,
+            tokenizer=tokenizer,
+        )
         if cfg.default_system_message and cfg.chat_template == "chatml":
             chat_template_string = chat_template_string.replace(
                 "You are a helpful assistant.", cfg.default_system_message
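A short recap of the fallback semantics used at tokenizer load time (a hedged sketch; it mirrors the new tests below):

from transformers import AutoTokenizer
from axolotl.utils.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

# The base tokenizer ships without a chat_template, so the suffix after
# "tokenizer_default_fallback_" picks the built-in template...
assert get_chat_template(
    "tokenizer_default_fallback_chatml", tokenizer=tokenizer
) == get_chat_template("chatml")

# ...but a template already present on the tokenizer takes precedence.
tokenizer.chat_template = "test_template"
assert (
    get_chat_template("tokenizer_default_fallback_chatml", tokenizer=tokenizer)
    == "test_template"
)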
||||
tests/prompt_strategies/test_chat_template_utils.py (new file, 125 lines)
@@ -0,0 +1,125 @@
+"""
+Tests for utils in axolotl.utils.chat_templates
+"""
+import unittest
+
+import pytest
+from transformers import AutoTokenizer
+
+from axolotl.utils.chat_templates import (
+    _TEMPLATES,
+    extract_chat_template_args,
+    get_chat_template,
+)
+
+
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+
+    return tokenizer
+
+
+class TestGetChatTemplateUtils:
+    """
+    Tests the get_chat_template function.
+    """
+
+    def test_known_chat_template(self):
+        chat_template_str = get_chat_template("llama3")
+        assert chat_template_str == _TEMPLATES["llama3"]
+
+    def test_invalid_chat_template(self):
+        with pytest.raises(ValueError) as exc:
+            get_chat_template("invalid_template")
+        assert str(exc) == "Template 'invalid_template' not found."
+
+    def test_tokenizer_default_no_tokenizer(self):
+        with pytest.raises(ValueError):
+            get_chat_template("tokenizer_default", tokenizer=None)
+
+    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
+        with pytest.raises(ValueError):
+            get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
+
+    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
+        llama3_tokenizer.chat_template = "test_template"
+        chat_template_str = get_chat_template(
+            "tokenizer_default", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == "test_template"
+
+    def test_tokenizer_default_fallback_no_tokenizer(self):
+        with pytest.raises(ValueError):
+            get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
+
+    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
+        self, llama3_tokenizer
+    ):
+        chat_template_str = get_chat_template(
+            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == get_chat_template("chatml")
+
+    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
+        self, llama3_tokenizer
+    ):
+        llama3_tokenizer.chat_template = "test_template"
+        chat_template_str = get_chat_template(
+            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
+        )
+        assert chat_template_str == "test_template"
+
+    def test_jinja_template_mode(self):
+        jinja_template = "example_jinja_template"
+        chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
+        assert chat_template_str == jinja_template
+
+    def test_jinja_template_mode_no_jinja_template(self):
+        with pytest.raises(ValueError):
+            get_chat_template("jinja", jinja_template=None)
+
+    def test_extract_chat_template_args(self):
+        # No ds_cfg
+        chat_template_choice, chat_template_jinja = extract_chat_template_args(
+            cfg={"chat_template": "chatml"},
+        )
+        assert chat_template_choice == "chatml"
+        assert chat_template_jinja is None
+
+        # ds_cfg provided
+        chat_template_choice, chat_template_jinja = extract_chat_template_args(
+            cfg={
+                "chat_template": "jinja",
+                "chat_template_jinja": "global_jinja_template",
+            },
+            ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
+        )
+        assert chat_template_choice == "llama3"
+        assert chat_template_jinja is None
+
+        # ds_cfg provided with jinja template
+        chat_template_choice, chat_template_jinja = extract_chat_template_args(
+            cfg={"chat_template": "chatml", "chat_template_jinja": None},
+            ds_cfg={
+                "chat_template": "jinja",
+                "chat_template_jinja": "ds_jinja_template",
+            },
+        )
+        assert chat_template_choice == "jinja"
+        assert chat_template_jinja == "ds_jinja_template"
+
+        # ds_cfg provided with no chat_template
+        chat_template_choice, chat_template_jinja = extract_chat_template_args(
+            cfg={
+                "chat_template": "jinja",
+                "chat_template_jinja": "global_jinja_template",
+            },
+            ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
+        )
+        assert chat_template_choice == "jinja"
+        assert chat_template_jinja == "global_jinja_template"
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -15,7 +15,7 @@ from axolotl.prompt_strategies.chat_template import (
     load,
 )
 from axolotl.prompters import IGNORE_TOKEN_ID
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template
 from axolotl.utils.dict import DictDefault

 logging.basicConfig(level=logging.DEBUG)
@@ -80,53 +80,6 @@ def fixture_llama3_tokenizer():
     return tokenizer


-class TestChatTemplates:
-    """
-    Tests the chat_templates function.
-    """
-
-    def test_invalid_chat_template(self):
-        with pytest.raises(ValueError) as exc:
-            chat_templates("invalid_template")
-        assert str(exc) == "Template 'invalid_template' not found."
-
-    def test_tokenizer_default_no_tokenizer(self):
-        with pytest.raises(ValueError):
-            chat_templates("tokenizer_default", tokenizer=None)
-
-    def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
-        with pytest.raises(ValueError):
-            chat_templates("tokenizer_default", tokenizer=llama3_tokenizer)
-
-    def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
-        llama3_tokenizer.chat_template = "test_template"
-        chat_template_str = chat_templates(
-            "tokenizer_default", tokenizer=llama3_tokenizer
-        )
-        assert chat_template_str == "test_template"
-
-    def test_tokenizer_default_fallback_no_tokenizer(self):
-        with pytest.raises(ValueError):
-            chat_templates("tokenizer_default_fallback_test", tokenizer=None)
-
-    def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
-        self, llama3_tokenizer
-    ):
-        chat_template_str = chat_templates(
-            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
-        )
-        assert chat_template_str == chat_templates("chatml")
-
-    def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
-        self, llama3_tokenizer
-    ):
-        llama3_tokenizer.chat_template = "test_template"
-        chat_template_str = chat_templates(
-            "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
-        )
-        assert chat_template_str == "test_template"
-
-
 class TestChatTemplateConfigurations:
     """
     Test class for various configurations of ChatTemplateStrategy.
@@ -143,7 +96,7 @@ class TestChatTemplateConfigurations:
     def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing with train_on_inputs=True")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=True,
             sequence_len=512,
@@ -186,7 +139,7 @@ class TestChatTemplateConfigurations:
     def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing with train_on_inputs=False")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
@@ -227,7 +180,7 @@ class TestChatTemplateConfigurations:
     def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing roles_to_train with assistant only")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
@@ -253,7 +206,7 @@ class TestChatTemplateConfigurations:
     def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing roles_to_train with all roles")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=True,
             sequence_len=512,
@@ -284,7 +237,7 @@ class TestChatTemplateConfigurations:
     def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing with empty roles_to_train")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
@@ -303,7 +256,7 @@ class TestChatTemplateConfigurations:
     def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing with train_on_eos='all'")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
@@ -328,7 +281,7 @@ class TestChatTemplateConfigurations:
     def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing with train_on_eos='turn'")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
@@ -376,7 +329,7 @@ class TestChatTemplateConfigurations:
     def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing with train_on_eos='last'")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
@@ -407,7 +360,7 @@ class TestChatTemplateConfigurations:
     def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing with train_on_eos='none'")
        strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             sequence_len=512,
@@ -433,7 +386,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with drop_system_message=True")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
+                llama3_tokenizer, get_chat_template("llama3"), drop_system_message=True
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -459,7 +412,7 @@ class TestChatTemplateConfigurations:
         }
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
+                llama3_tokenizer, get_chat_template("llama3"), roles=custom_roles
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -511,7 +464,7 @@ class TestChatTemplateConfigurations:
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
                 llama3_tokenizer,
-                chat_templates("llama3"),
+                get_chat_template("llama3"),
                 message_field_training="train",
                 message_field_training_detail="train_detail",
             ),
@@ -775,7 +728,7 @@ class TestAssistantChatTemplateLlama3:
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
                 llama3_tokenizer,
-                chat_templates("llama3"),
+                get_chat_template("llama3"),
                 message_field_role="role",
                 message_field_content="content",
                 roles={
@@ -816,7 +769,7 @@ class TestAssistantChatTemplateLlama3:
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
                 llama3_tokenizer,
-                chat_templates("llama3"),
+                get_chat_template("llama3"),
                 message_field_role="role",
                 message_field_content="content",
                 message_field_training="training",
@@ -873,7 +826,7 @@ class TestSharegptChatTemplateLlama3:
     def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
         LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             train_on_eos="none",
@@ -923,7 +876,7 @@ class TestSharegptChatTemplateLlama3:
     def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
         LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             train_on_eos="none",
@@ -973,7 +926,7 @@ class TestSharegptChatTemplateLlama3:
     def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
         LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
         strategy = ChatTemplateStrategy(
-            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
             train_on_eos="none",