chore: remove redundant function

This commit is contained in:
NanoCode012
2024-10-10 16:15:15 +07:00
parent b8056d04d9
commit f61e2fc7dc
5 changed files with 21 additions and 39 deletions

View File

@@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -250,7 +250,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = chat_templates(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -61,7 +61,7 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -1523,7 +1523,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates(
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
)

View File

@@ -47,10 +47,12 @@ def get_chat_template(
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
):
"""
Finds the correct chat_template for the tokenizer_config.
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
Args:
user_choice (str): The user's choice of template.
jinja_template (Optional[str], optional): The jinja template string. Defaults to None.
tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None.
Returns:
str: The chosen template string.
@@ -123,26 +125,6 @@ def get_chat_template_from_config(
)
def chat_templates(user_choice: str):
"""
Finds the correct chat_template for the tokenizer_config.
Args:
user_choice (str): The user's choice of template.
Returns:
str: The chosen template string.
Raises:
ValueError: If the user_choice is not found in the templates.
"""
if user_choice in CHAT_TEMPLATES:
return CHAT_TEMPLATES[user_choice]
raise ValueError(f"Template '{user_choice}' not found.")
def register_chat_template(template_name: str, chat_template: str):
"""
Registers chat templates.

View File

@@ -764,7 +764,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
phi35_tokenizer,
chat_template=chat_templates("phi_35"),
chat_template=get_chat_template("phi_35"),
message_field_role="role",
message_field_content="content",
roles={

View File

@@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.chat_templates import get_chat_template
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@@ -35,7 +35,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
@@ -80,7 +80,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -123,7 +123,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -151,7 +151,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
@@ -184,7 +184,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -205,7 +205,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -232,7 +232,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -282,7 +282,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -315,7 +315,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=chat_templates("llama3")
llama3_tokenizer, chat_template=get_chat_template("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -343,7 +343,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
drop_system_message=True,
),
tokenizer=llama3_tokenizer,
@@ -371,7 +371,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
roles=custom_roles,
),
tokenizer=llama3_tokenizer,
@@ -424,7 +424,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=chat_templates("llama3"),
chat_template=get_chat_template("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),