Feat: Add support for tokenizer’s or custom jinja chat_template (#1970)

* Allow using tokenizer's default chat template with fallbacks

Summary of changes:

1. Adds `tokenizer_default` as option for `chat_template` in
   `chat_template` prompt strategy that allows using the chat template
   from tokenizer's config.json
2. Allows falling back to chat templates available in axolotl if
   tokenizer does not have a chat template
3. Adds a mistral chat template which supports system message - taken
   from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja

---

Why?

Many popular models are not trained with chatml format. As a result for
the model to correctly learn chatml we have to turn on train_on_inputs
which requires more compute and time. If we can use the model's already
learned chat template we can just learn the output tokens

---

Todo:

- Write tests

* Add tests

* Fix lint and bug post merge from main

* Add option `chat_template_jinja` to provide a jinja template

* remove custom mistral template

* Address review comments and add docs

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* fix: set default to tokenizer template

* Merge branch 'main' into cj_tokenizer_default_prompt_template

* chore: remove redundant function

* fix: re-arrange enum declaration position

* fix: refactor artifact left from main merge

* feat(doc): updated config with chat template options and clarified examples

* chore: clarify doc

* chore: added example for non-default template

* chore: refactor

* fix: test

* fix: config being dropped and unittest to catch that

* chore: lint

* chore: skip duplicate

* fix: rename var after merge

* feat: add test for levy's dpo case

* fix: remove default setting on edge case where chat template overriden in dataset section

* feat: handle sharegpt deprecation better in docs

* feat: add example using fallback

* feat: handles chat_template requiring specific user/assistant order

* fix: update test based on new defaults

* fix: imported name incorrectly updated on merge

* chore: lint

* fix: update dummy message to prevent potential overlap with real content

* fix(doc): formatting

* fix: update bradleyterry to use new chat_template

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
This commit is contained in:
NanoCode012
2024-10-29 10:14:51 +07:00
committed by GitHub
parent e1e0556c99
commit bfc77b0f36
20 changed files with 900 additions and 118 deletions

View File

@@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = chat_templates(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -63,7 +63,7 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -1594,7 +1594,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates(
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
)

View File

@@ -6,7 +6,7 @@ import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,13 +2,18 @@
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -27,18 +32,24 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
return {
@@ -53,15 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail"
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
@@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = BTChatTemplateStrategy(

View File

@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -405,10 +405,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),

View File

@@ -2,15 +2,16 @@
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_str = chat_templates(cfg.chat_template)
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -30,6 +31,12 @@ def default(
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
messages = sample[field_messages]
messages = [
{
@@ -46,28 +53,29 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[chosen],
[dummy_user_message, chosen],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
[dummy_user_message, rejected],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
class Message(BaseModel):
@@ -28,18 +28,13 @@ def load(
"""
chatml transforms for datasets with system, input, chosen, rejected
"""
chat_template = chat_templates("chatml")
if ds_cfg and "chat_template" in ds_cfg:
chat_template = ds_cfg["chat_template"]
try:
chat_template = chat_templates(chat_template)
except ValueError:
pass
tokenizer.chat_template = chat_template
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
tokenizer.chat_template = chat_template_string
return ORPOTokenizingStrategy(
ORPOPrompter(chat_template, tokenizer),
ORPOPrompter(chat_template_string, tokenizer),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -248,28 +243,30 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, tokenizer=tokenizer
)
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_str,
chat_template=chat_template_string,
tokenize=False,
)[prompt_str_len:]

View File

@@ -62,7 +62,7 @@ def build_loader(
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
)
conversation = (
ds_cfg["conversation"]

View File

@@ -2,8 +2,19 @@
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional
CHAT_TEMPLATES = {
if TYPE_CHECKING:
from transformers import PreTrainedTokenizerBase
LOG = logging.getLogger("axolotl.utils.chat_templates")
_JINJA_TEMPALTE_CHOICE = "jinja"
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
_CHAT_TEMPLATES = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
@@ -21,12 +32,18 @@ CHAT_TEMPLATES = {
}
def chat_templates(user_choice: str):
def get_chat_template(
user_choice: str,
jinja_template: Optional[str] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
):
"""
Finds the correct chat_template for the tokenizer_config.
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
Args:
user_choice (str): The user's choice of template.
jinja_template (Optional[str], optional): The jinja template string. Defaults to None.
tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None.
Returns:
str: The chosen template string.
@@ -34,13 +51,71 @@ def chat_templates(user_choice: str):
Raises:
ValueError: If the user_choice is not found in the templates.
"""
if user_choice == _JINJA_TEMPALTE_CHOICE:
if not jinja_template:
raise ValueError(
f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}"
)
return jinja_template
if user_choice in CHAT_TEMPLATES:
return CHAT_TEMPLATES[user_choice]
if user_choice == _DEFAULT_TEMPLATE_CHOICE:
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
)
if not tokenizer.chat_template:
raise ValueError(
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
f"Please add a chat_template in tokenizer config"
)
return tokenizer.chat_template
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
)
if tokenizer.chat_template:
return tokenizer.chat_template
user_choice = user_choice[
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
]
LOG.warning(
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
)
if user_choice in _CHAT_TEMPLATES:
return _CHAT_TEMPLATES[user_choice]
raise ValueError(f"Template '{user_choice}' not found.")
def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
if ds_cfg and ds_cfg.get("chat_template"):
chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = ds_cfg.get("chat_template_jinja")
else:
chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = cfg.get("chat_template_jinja")
return chat_template_choice, chat_template_jinja
def get_chat_template_from_config(
cfg,
ds_cfg: Optional[Dict[str, Any]] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
return get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
def register_chat_template(template_name: str, chat_template: str):
"""
Registers chat templates.
@@ -50,7 +125,7 @@ def register_chat_template(template_name: str, chat_template: str):
chat_template (str): The template string.
"""
if template_name in CHAT_TEMPLATES:
if template_name in _CHAT_TEMPLATES:
raise ValueError(f"Template '{template_name}' already exists.")
CHAT_TEMPLATES[template_name] = chat_template
_CHAT_TEMPLATES[template_name] = chat_template

View File

@@ -228,6 +228,7 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].chat_template = cfg.chat_template
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):

View File

@@ -8,9 +8,16 @@ import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from pydantic import (
BaseModel,
Field,
StringConstraints,
conlist,
field_validator,
model_validator,
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
@@ -21,6 +28,37 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -105,13 +143,19 @@ class SFTDataset(BaseModel):
input_transform: Optional[str] = None
shards: Optional[int] = None
conversation: Optional[str] = None
chat_template: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
Union[
ChatTemplate,
str,
]
] = None
chat_template_jinja: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
input_format: Optional[str] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -122,13 +166,32 @@ class SFTDataset(BaseModel):
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
@@ -174,35 +237,6 @@ class KTODataset(BaseModel):
revision: Optional[str] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -719,7 +753,13 @@ class AxolotlInputConfig(
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
chat_template: Optional[ChatTemplate] = None
chat_template: Optional[
Union[
ChatTemplate,
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
]
] = None
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
@@ -828,6 +868,23 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):

View File

@@ -53,7 +53,7 @@ from axolotl.monkeypatch.multipack import (
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
@@ -296,7 +296,10 @@ def load_tokenizer(cfg):
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template:
chat_template_string = chat_templates(cfg.chat_template)
chat_template_string = get_chat_template_from_config(
cfg=cfg,
tokenizer=tokenizer,
)
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message