Add option chat_template_jinja to provide a jinja template
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -338,13 +338,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
ds_cfg = ds_cfg or {}
|
ds_cfg = ds_cfg or {}
|
||||||
chat_template = ds_cfg.get("chat_template", "chatml")
|
chat_template_string = get_chat_template_from_config(
|
||||||
chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
|
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||||
LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")
|
)
|
||||||
|
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||||
|
|
||||||
prompter_params = {
|
prompter_params = {
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
"chat_template": chat_template_str,
|
"chat_template": chat_template_string,
|
||||||
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
||||||
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
||||||
"message_field_training": ds_cfg.get("message_field_training", "training"),
|
"message_field_training": ds_cfg.get("message_field_training", "training"),
|
||||||
|
|||||||
@@ -2,15 +2,16 @@
|
|||||||
DPO prompt strategies for using tokenizer chat templates.
|
DPO prompt strategies for using tokenizer chat templates.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
cfg, dataset_idx=0, **kwargs
|
cfg, dataset_idx=0, **kwargs
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
ds_cfg = cfg["datasets"][dataset_idx]
|
ds_cfg = cfg["datasets"][dataset_idx]
|
||||||
chat_template_str = chat_templates(cfg.chat_template)
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
|
cfg=cfg, ds_cfg=ds_cfg
|
||||||
|
)
|
||||||
field_messages = ds_cfg.get("field_messages", "messages")
|
field_messages = ds_cfg.get("field_messages", "messages")
|
||||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||||
@@ -30,6 +31,12 @@ def default(
|
|||||||
role_map[source] = target
|
role_map[source] = target
|
||||||
|
|
||||||
def transform_fn(sample, tokenizer=None):
|
def transform_fn(sample, tokenizer=None):
|
||||||
|
chat_template_string = get_chat_template(
|
||||||
|
user_choice=chat_template_choice,
|
||||||
|
jinja_template=chat_template_jinja,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
messages = sample[field_messages]
|
messages = sample[field_messages]
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -51,14 +58,14 @@ def default(
|
|||||||
result["prompt"] = tokenizer.apply_chat_template(
|
result["prompt"] = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
chat_template=chat_template_str,
|
chat_template=chat_template_string,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
result["chosen"] = tokenizer.apply_chat_template(
|
result["chosen"] = tokenizer.apply_chat_template(
|
||||||
[chosen],
|
[chosen],
|
||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
chat_template=chat_template_str,
|
chat_template=chat_template_string,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
chosen_strip_index = result["chosen"].find(chosen["content"])
|
chosen_strip_index = result["chosen"].find(chosen["content"])
|
||||||
@@ -67,7 +74,7 @@ def default(
|
|||||||
result["rejected"] = tokenizer.apply_chat_template(
|
result["rejected"] = tokenizer.apply_chat_template(
|
||||||
[rejected],
|
[rejected],
|
||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
chat_template=chat_template_str,
|
chat_template=chat_template_string,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
rejected_strip_index = result["rejected"].find(rejected["content"])
|
rejected_strip_index = result["rejected"].find(rejected["content"])
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
|
||||||
from axolotl.prompters import Prompter
|
from axolotl.prompters import Prompter
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
@@ -28,18 +28,13 @@ def load(
|
|||||||
"""
|
"""
|
||||||
chatml transforms for datasets with system, input, chosen, rejected
|
chatml transforms for datasets with system, input, chosen, rejected
|
||||||
"""
|
"""
|
||||||
|
chat_template_string = get_chat_template_from_config(
|
||||||
chat_template = chat_templates("chatml")
|
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||||
if ds_cfg and "chat_template" in ds_cfg:
|
)
|
||||||
chat_template = ds_cfg["chat_template"]
|
tokenizer.chat_template = chat_template_string
|
||||||
try:
|
|
||||||
chat_template = chat_templates(chat_template, tokenizer=tokenizer)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
tokenizer.chat_template = chat_template
|
|
||||||
|
|
||||||
return ORPOTokenizingStrategy(
|
return ORPOTokenizingStrategy(
|
||||||
ORPOPrompter(chat_template, tokenizer),
|
ORPOPrompter(chat_template_string, tokenizer),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
@@ -251,25 +246,27 @@ def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-a
|
|||||||
def transform_fn(sample, tokenizer=None):
|
def transform_fn(sample, tokenizer=None):
|
||||||
res = {}
|
res = {}
|
||||||
|
|
||||||
chat_template_str = chat_templates(cfg.chat_template, tokenizer=tokenizer)
|
chat_template_string = get_chat_template_from_config(
|
||||||
|
cfg=cfg, tokenizer=tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
res["prompt"] = tokenizer.apply_chat_template(
|
res["prompt"] = tokenizer.apply_chat_template(
|
||||||
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
chat_template=chat_template_str,
|
chat_template=chat_template_string,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
prompt_str_len = len(res["prompt"])
|
prompt_str_len = len(res["prompt"])
|
||||||
res["chosen"] = tokenizer.apply_chat_template(
|
res["chosen"] = tokenizer.apply_chat_template(
|
||||||
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
chat_template=chat_template_str,
|
chat_template=chat_template_string,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)[prompt_str_len:]
|
)[prompt_str_len:]
|
||||||
res["rejected"] = tokenizer.apply_chat_template(
|
res["rejected"] = tokenizer.apply_chat_template(
|
||||||
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
chat_template=chat_template_str,
|
chat_template=chat_template_string,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)[prompt_str_len:]
|
)[prompt_str_len:]
|
||||||
|
|
||||||
|
|||||||
@@ -3,14 +3,34 @@ This module provides functionality for selecting chat templates based on user ch
|
|||||||
These templates are used for formatting messages in a conversation.
|
These templates are used for formatting messages in a conversation.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.chat_templates")
|
LOG = logging.getLogger("axolotl.utils.chat_templates")
|
||||||
|
|
||||||
|
_JINJA_TEMPALTE_CHOICE = "jinja"
|
||||||
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
|
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
|
||||||
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
|
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
|
||||||
|
|
||||||
|
_TEMPLATES = {
|
||||||
|
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||||
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
|
"mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
|
||||||
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
|
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||||
|
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
||||||
|
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
||||||
|
}
|
||||||
|
|
||||||
def chat_templates(user_choice: str, tokenizer=None):
|
|
||||||
|
def get_chat_template(
|
||||||
|
user_choice: str,
|
||||||
|
jinja_template: Optional[str] = None,
|
||||||
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Finds the correct chat_template for the tokenizer_config.
|
Finds the correct chat_template for the tokenizer_config.
|
||||||
|
|
||||||
@@ -23,17 +43,12 @@ def chat_templates(user_choice: str, tokenizer=None):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the user_choice is not found in the templates.
|
ValueError: If the user_choice is not found in the templates.
|
||||||
"""
|
"""
|
||||||
|
if user_choice == _JINJA_TEMPALTE_CHOICE:
|
||||||
templates = {
|
if not jinja_template:
|
||||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
raise ValueError(
|
||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}"
|
||||||
"mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
|
)
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
return jinja_template
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
|
||||||
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
|
||||||
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
|
||||||
}
|
|
||||||
|
|
||||||
if user_choice == _DEFAULT_TEMPLATE_CHOICE:
|
if user_choice == _DEFAULT_TEMPLATE_CHOICE:
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
@@ -62,7 +77,33 @@ def chat_templates(user_choice: str, tokenizer=None):
|
|||||||
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
|
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_choice in templates:
|
if user_choice in _TEMPLATES:
|
||||||
return templates[user_choice]
|
return _TEMPLATES[user_choice]
|
||||||
|
|
||||||
raise ValueError(f"Template '{user_choice}' not found.")
|
raise ValueError(f"Template '{user_choice}' not found.")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
if ds_cfg and ds_cfg.get("chat_template"):
|
||||||
|
chat_template_choice = ds_cfg.get("chat_template") or "chatml"
|
||||||
|
chat_template_jinja = ds_cfg.get("chat_template_jinja")
|
||||||
|
else:
|
||||||
|
chat_template_choice = cfg.get("chat_template") or "chatml"
|
||||||
|
chat_template_jinja = cfg.get("chat_template_jinja")
|
||||||
|
return chat_template_choice, chat_template_jinja
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_template_from_config(
|
||||||
|
cfg,
|
||||||
|
ds_cfg: Optional[Dict[str, Any]] = None,
|
||||||
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
|
) -> str:
|
||||||
|
ds_cfg = ds_cfg or {}
|
||||||
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
|
cfg=cfg, ds_cfg=ds_cfg
|
||||||
|
)
|
||||||
|
return get_chat_template(
|
||||||
|
user_choice=chat_template_choice,
|
||||||
|
jinja_template=chat_template_jinja,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ def normalize_cfg_datasets(cfg):
|
|||||||
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
|
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
|
||||||
)
|
)
|
||||||
cfg.datasets[idx].chat_template = cfg.chat_template
|
cfg.datasets[idx].chat_template = cfg.chat_template
|
||||||
|
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||||
|
|||||||
@@ -28,6 +28,21 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
|
|||||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplate(str, Enum):
|
||||||
|
"""Chat templates configuration subset"""
|
||||||
|
|
||||||
|
jinja = "jinja" # pylint: disable=invalid-name
|
||||||
|
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||||
|
chatml = "chatml" # pylint: disable=invalid-name
|
||||||
|
inst = "inst" # pylint: disable=invalid-name
|
||||||
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
|
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||||
|
mistral = "mistral" # pylint: disable=invalid-name
|
||||||
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
class DeprecatedParameters(BaseModel):
|
||||||
"""configurations that are deprecated"""
|
"""configurations that are deprecated"""
|
||||||
|
|
||||||
@@ -111,12 +126,15 @@ class SFTDataset(BaseModel):
|
|||||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||||
shards: Optional[int] = None
|
shards: Optional[int] = None
|
||||||
conversation: Optional[str] = None
|
conversation: Optional[str] = None
|
||||||
chat_template: Optional[str] = None
|
chat_template: Union[
|
||||||
|
ChatTemplate,
|
||||||
|
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||||
|
] = ChatTemplate.chatml
|
||||||
|
chat_template_jinja: Optional[str] = None
|
||||||
data_files: Optional[Union[str, List[str]]] = None
|
data_files: Optional[Union[str, List[str]]] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
ds_type: Optional[str] = None
|
ds_type: Optional[str] = None
|
||||||
train_on_split: Optional[str] = None
|
train_on_split: Optional[str] = None
|
||||||
|
|
||||||
field: Optional[str] = None
|
field: Optional[str] = None
|
||||||
field_human: Optional[str] = None
|
field_human: Optional[str] = None
|
||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
@@ -127,12 +145,22 @@ class SFTDataset(BaseModel):
|
|||||||
message_field_training_detail: Optional[str] = None
|
message_field_training_detail: Optional[str] = None
|
||||||
roles_to_train: Optional[List[str]] = None
|
roles_to_train: Optional[List[str]] = None
|
||||||
train_on_eos: Optional[str] = None
|
train_on_eos: Optional[str] = None
|
||||||
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
drop_system_message: Optional[bool] = None
|
drop_system_message: Optional[bool] = None
|
||||||
|
|
||||||
trust_remote_code: Optional[bool] = False
|
trust_remote_code: Optional[bool] = False
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_chat_template_config(cls, data):
|
||||||
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
|
"chat_template_jinja"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
"""User defined typing for DPO"""
|
"""User defined typing for DPO"""
|
||||||
@@ -186,20 +214,6 @@ class RLType(str, Enum):
|
|||||||
simpo = "simpo" # pylint: disable=invalid-name
|
simpo = "simpo" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
|
||||||
"""Chat templates configuration subset"""
|
|
||||||
|
|
||||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
|
||||||
chatml = "chatml" # pylint: disable=invalid-name
|
|
||||||
inst = "inst" # pylint: disable=invalid-name
|
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
|
||||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
|
||||||
mistral = "mistral" # pylint: disable=invalid-name
|
|
||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
"""LoftQ configuration subset"""
|
"""LoftQ configuration subset"""
|
||||||
|
|
||||||
@@ -671,12 +685,11 @@ class AxolotlInputConfig(
|
|||||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||||
low_cpu_mem_usage: Optional[bool] = None
|
low_cpu_mem_usage: Optional[bool] = None
|
||||||
|
|
||||||
chat_template: Optional[
|
chat_template: Union[
|
||||||
Union[
|
ChatTemplate,
|
||||||
ChatTemplate,
|
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
] = ChatTemplate.chatml
|
||||||
]
|
chat_template_jinja: Optional[str] = None
|
||||||
] = None
|
|
||||||
default_system_message: Optional[str] = None
|
default_system_message: Optional[str] = None
|
||||||
|
|
||||||
fix_untrained_tokens: Optional[bool] = None
|
fix_untrained_tokens: Optional[bool] = None
|
||||||
@@ -785,6 +798,18 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_chat_template_config(cls, data):
|
||||||
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
|
"chat_template_jinja"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_wo_flash(cls, data):
|
def check_sample_packing_wo_flash(cls, data):
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from axolotl.monkeypatch.multipack import (
|
|||||||
)
|
)
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import zero_only
|
from axolotl.utils.distributed import zero_only
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
||||||
@@ -285,7 +285,10 @@ def load_tokenizer(cfg):
|
|||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
if cfg.chat_template:
|
if cfg.chat_template:
|
||||||
chat_template_string = chat_templates(cfg.chat_template, tokenizer=tokenizer)
|
chat_template_string = get_chat_template_from_config(
|
||||||
|
cfg=cfg,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
if cfg.default_system_message and cfg.chat_template == "chatml":
|
if cfg.default_system_message and cfg.chat_template == "chatml":
|
||||||
chat_template_string = chat_template_string.replace(
|
chat_template_string = chat_template_string.replace(
|
||||||
"You are a helpful assistant.", cfg.default_system_message
|
"You are a helpful assistant.", cfg.default_system_message
|
||||||
|
|||||||
125
tests/prompt_strategies/test_chat_template_utils.py
Normal file
125
tests/prompt_strategies/test_chat_template_utils.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""
|
||||||
|
Tests for utils in axolotl.utils.chat_templates
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.chat_templates import (
|
||||||
|
_TEMPLATES,
|
||||||
|
extract_chat_template_args,
|
||||||
|
get_chat_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
|
def fixture_llama3_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetChatTemplateUtils:
|
||||||
|
"""
|
||||||
|
Tests the get_chat_template function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_known_chat_template(self):
|
||||||
|
chat_template_str = get_chat_template("llama3")
|
||||||
|
assert chat_template_str == _TEMPLATES["llama3"]
|
||||||
|
|
||||||
|
def test_invalid_chat_template(self):
|
||||||
|
with pytest.raises(ValueError) as exc:
|
||||||
|
get_chat_template("invalid_template")
|
||||||
|
assert str(exc) == "Template 'invalid_template' not found."
|
||||||
|
|
||||||
|
def test_tokenizer_default_no_tokenizer(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
get_chat_template("tokenizer_default", tokenizer=None)
|
||||||
|
|
||||||
|
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
|
||||||
|
|
||||||
|
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
|
||||||
|
llama3_tokenizer.chat_template = "test_template"
|
||||||
|
chat_template_str = get_chat_template(
|
||||||
|
"tokenizer_default", tokenizer=llama3_tokenizer
|
||||||
|
)
|
||||||
|
assert chat_template_str == "test_template"
|
||||||
|
|
||||||
|
def test_tokenizer_default_fallback_no_tokenizer(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
|
||||||
|
|
||||||
|
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
|
||||||
|
self, llama3_tokenizer
|
||||||
|
):
|
||||||
|
chat_template_str = get_chat_template(
|
||||||
|
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
||||||
|
)
|
||||||
|
assert chat_template_str == get_chat_template("chatml")
|
||||||
|
|
||||||
|
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
|
||||||
|
self, llama3_tokenizer
|
||||||
|
):
|
||||||
|
llama3_tokenizer.chat_template = "test_template"
|
||||||
|
chat_template_str = get_chat_template(
|
||||||
|
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
||||||
|
)
|
||||||
|
assert chat_template_str == "test_template"
|
||||||
|
|
||||||
|
def test_jinja_template_mode(self):
|
||||||
|
jinja_template = "example_jinja_template"
|
||||||
|
chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
|
||||||
|
assert chat_template_str == jinja_template
|
||||||
|
|
||||||
|
def test_jinja_template_mode_no_jinja_template(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
get_chat_template("jinja", jinja_template=None)
|
||||||
|
|
||||||
|
def test_extract_chat_template_args(self):
|
||||||
|
# No ds_cfg
|
||||||
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
|
cfg={"chat_template": "chatml"},
|
||||||
|
)
|
||||||
|
assert chat_template_choice == "chatml"
|
||||||
|
assert chat_template_jinja is None
|
||||||
|
|
||||||
|
# ds_cfg provided
|
||||||
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
|
cfg={
|
||||||
|
"chat_template": "jinja",
|
||||||
|
"chat_template_jinja": "global_jinja_template",
|
||||||
|
},
|
||||||
|
ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
|
||||||
|
)
|
||||||
|
assert chat_template_choice == "llama3"
|
||||||
|
assert chat_template_jinja is None
|
||||||
|
|
||||||
|
# ds_cfg provided with jinja template
|
||||||
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
|
cfg={"chat_template": "chatml", "chat_template_jinja": None},
|
||||||
|
ds_cfg={
|
||||||
|
"chat_template": "jinja",
|
||||||
|
"chat_template_jinja": "ds_jinja_template",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert chat_template_choice == "jinja"
|
||||||
|
assert chat_template_jinja == "ds_jinja_template"
|
||||||
|
|
||||||
|
# ds_cfg provided with no chat_template
|
||||||
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
|
cfg={
|
||||||
|
"chat_template": "jinja",
|
||||||
|
"chat_template_jinja": "global_jinja_template",
|
||||||
|
},
|
||||||
|
ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
|
||||||
|
)
|
||||||
|
assert chat_template_choice == "jinja"
|
||||||
|
assert chat_template_jinja == "global_jinja_template"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -15,7 +15,7 @@ from axolotl.prompt_strategies.chat_template import (
|
|||||||
load,
|
load,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import get_chat_template
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -80,53 +80,6 @@ def fixture_llama3_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
class TestChatTemplates:
|
|
||||||
"""
|
|
||||||
Tests the chat_templates function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_invalid_chat_template(self):
|
|
||||||
with pytest.raises(ValueError) as exc:
|
|
||||||
chat_templates("invalid_template")
|
|
||||||
assert str(exc) == "Template 'invalid_template' not found."
|
|
||||||
|
|
||||||
def test_tokenizer_default_no_tokenizer(self):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
chat_templates("tokenizer_default", tokenizer=None)
|
|
||||||
|
|
||||||
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
chat_templates("tokenizer_default", tokenizer=llama3_tokenizer)
|
|
||||||
|
|
||||||
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
|
|
||||||
llama3_tokenizer.chat_template = "test_template"
|
|
||||||
chat_template_str = chat_templates(
|
|
||||||
"tokenizer_default", tokenizer=llama3_tokenizer
|
|
||||||
)
|
|
||||||
assert chat_template_str == "test_template"
|
|
||||||
|
|
||||||
def test_tokenizer_default_fallback_no_tokenizer(self):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
chat_templates("tokenizer_default_fallback_test", tokenizer=None)
|
|
||||||
|
|
||||||
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
|
|
||||||
self, llama3_tokenizer
|
|
||||||
):
|
|
||||||
chat_template_str = chat_templates(
|
|
||||||
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
|
||||||
)
|
|
||||||
assert chat_template_str == chat_templates("chatml")
|
|
||||||
|
|
||||||
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
|
|
||||||
self, llama3_tokenizer
|
|
||||||
):
|
|
||||||
llama3_tokenizer.chat_template = "test_template"
|
|
||||||
chat_template_str = chat_templates(
|
|
||||||
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
|
||||||
)
|
|
||||||
assert chat_template_str == "test_template"
|
|
||||||
|
|
||||||
|
|
||||||
class TestChatTemplateConfigurations:
|
class TestChatTemplateConfigurations:
|
||||||
"""
|
"""
|
||||||
Test class for various configurations of ChatTemplateStrategy.
|
Test class for various configurations of ChatTemplateStrategy.
|
||||||
@@ -143,7 +96,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_inputs=True")
|
LOG.info("Testing with train_on_inputs=True")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -186,7 +139,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_inputs=False")
|
LOG.info("Testing with train_on_inputs=False")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -227,7 +180,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing roles_to_train with assistant only")
|
LOG.info("Testing roles_to_train with assistant only")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -253,7 +206,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing roles_to_train with all roles")
|
LOG.info("Testing roles_to_train with all roles")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -284,7 +237,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with empty roles_to_train")
|
LOG.info("Testing with empty roles_to_train")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -303,7 +256,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='all'")
|
LOG.info("Testing with train_on_eos='all'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -328,7 +281,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='turn'")
|
LOG.info("Testing with train_on_eos='turn'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -376,7 +329,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='last'")
|
LOG.info("Testing with train_on_eos='last'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -407,7 +360,7 @@ class TestChatTemplateConfigurations:
|
|||||||
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing with train_on_eos='none'")
|
LOG.info("Testing with train_on_eos='none'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
@@ -433,7 +386,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with drop_system_message=True")
|
LOG.info("Testing with drop_system_message=True")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
|
llama3_tokenizer, get_chat_template("llama3"), drop_system_message=True
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -459,7 +412,7 @@ class TestChatTemplateConfigurations:
|
|||||||
}
|
}
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
|
llama3_tokenizer, get_chat_template("llama3"), roles=custom_roles
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -511,7 +464,7 @@ class TestChatTemplateConfigurations:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_templates("llama3"),
|
get_chat_template("llama3"),
|
||||||
message_field_training="train",
|
message_field_training="train",
|
||||||
message_field_training_detail="train_detail",
|
message_field_training_detail="train_detail",
|
||||||
),
|
),
|
||||||
@@ -775,7 +728,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_templates("llama3"),
|
get_chat_template("llama3"),
|
||||||
message_field_role="role",
|
message_field_role="role",
|
||||||
message_field_content="content",
|
message_field_content="content",
|
||||||
roles={
|
roles={
|
||||||
@@ -816,7 +769,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_templates("llama3"),
|
get_chat_template("llama3"),
|
||||||
message_field_role="role",
|
message_field_role="role",
|
||||||
message_field_content="content",
|
message_field_content="content",
|
||||||
message_field_training="training",
|
message_field_training="training",
|
||||||
@@ -873,7 +826,7 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
||||||
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
train_on_eos="none",
|
train_on_eos="none",
|
||||||
@@ -923,7 +876,7 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
||||||
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
train_on_eos="none",
|
train_on_eos="none",
|
||||||
@@ -973,7 +926,7 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
||||||
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
ChatTemplatePrompter(llama3_tokenizer, get_chat_template("llama3")),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
train_on_eos="none",
|
train_on_eos="none",
|
||||||
|
|||||||
Reference in New Issue
Block a user