diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index a1d84b6a1..4171f28b2 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.config import ( normalize_cfg_datasets, normalize_config, @@ -250,7 +250,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) elif cfg.chat_template: - chat_template_str = chat_templates(cfg.chat_template) + chat_template_str = get_chat_template(cfg.chat_template) model = model.to(cfg.device, dtype=cfg.torch_dtype) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4893e63dc..b73849d7a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -61,7 +61,7 @@ from axolotl.utils.callbacks import ( log_prediction_callback_factory, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, @@ -1523,7 +1523,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: - training_arguments_kwargs["chat_template"] = chat_templates( + training_arguments_kwargs["chat_template"] = get_chat_template( self.cfg.chat_template ) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 011d09129..c9e0befe6 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -47,10 +47,12 @@ def get_chat_template( tokenizer: Optional["PreTrainedTokenizerBase"] = None, ): """ - Finds the correct chat_template for the tokenizer_config. + Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer. Args: user_choice (str): The user's choice of template. + jinja_template (Optional[str], optional): The jinja template string. Defaults to None. + tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None. Returns: str: The chosen template string. @@ -123,26 +125,6 @@ def get_chat_template_from_config( ) -def chat_templates(user_choice: str): - """ - Finds the correct chat_template for the tokenizer_config. - - Args: - user_choice (str): The user's choice of template. - - Returns: - str: The chosen template string. - - Raises: - ValueError: If the user_choice is not found in the templates. - """ - - if user_choice in CHAT_TEMPLATES: - return CHAT_TEMPLATES[user_choice] - - raise ValueError(f"Template '{user_choice}' not found.") - - def register_chat_template(template_name: str, chat_template: str): """ Registers chat templates. diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index ce1079a4a..9e8f1dfcc 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -764,7 +764,7 @@ class TestAssistantChatTemplateLlama3: strategy = ChatTemplateStrategy( ChatTemplatePrompter( phi35_tokenizer, - chat_template=chat_templates("phi_35"), + chat_template=get_chat_template("phi_35"), message_field_role="role", message_field_content="content", roles={ diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 50429e3a2..be8e3ccdf 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import ( ChatTemplateStrategy, ) from axolotl.prompters import IGNORE_TOKEN_ID -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger("axolotl") @@ -35,7 +35,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with train_on_inputs=True") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=True, @@ -80,7 +80,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with train_on_inputs=False") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -123,7 +123,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing roles_to_train with assistant only") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -151,7 +151,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing roles_to_train with all roles") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=True, @@ -184,7 +184,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with empty roles_to_train") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -205,7 +205,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with train_on_eos='all'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -232,7 +232,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with train_on_eos='turn'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -282,7 +282,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with train_on_eos='last'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -315,7 +315,7 @@ class TestChatTemplateConfigurations: LOG.info("Testing with train_on_eos='none'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -343,7 +343,7 @@ class TestChatTemplateConfigurations: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), drop_system_message=True, ), tokenizer=llama3_tokenizer, @@ -371,7 +371,7 @@ class TestChatTemplateConfigurations: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), roles=custom_roles, ), tokenizer=llama3_tokenizer, @@ -424,7 +424,7 @@ class TestChatTemplateConfigurations: strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), message_field_training="train", message_field_training_detail="train_detail", ),