chore: remove redundant function
This commit is contained in:
@@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import get_chat_template
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
@@ -250,7 +250,7 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
elif cfg.chat_template:
|
elif cfg.chat_template:
|
||||||
chat_template_str = chat_templates(cfg.chat_template)
|
chat_template_str = get_chat_template(cfg.chat_template)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ from axolotl.utils.callbacks import (
|
|||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import get_chat_template
|
||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
@@ -1523,7 +1523,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||||
if self.cfg.chat_template:
|
if self.cfg.chat_template:
|
||||||
training_arguments_kwargs["chat_template"] = chat_templates(
|
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||||
self.cfg.chat_template
|
self.cfg.chat_template
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -47,10 +47,12 @@ def get_chat_template(
|
|||||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Finds the correct chat_template for the tokenizer_config.
|
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_choice (str): The user's choice of template.
|
user_choice (str): The user's choice of template.
|
||||||
|
jinja_template (Optional[str], optional): The jinja template string. Defaults to None.
|
||||||
|
tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The chosen template string.
|
str: The chosen template string.
|
||||||
@@ -123,26 +125,6 @@ def get_chat_template_from_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def chat_templates(user_choice: str):
|
|
||||||
"""
|
|
||||||
Finds the correct chat_template for the tokenizer_config.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_choice (str): The user's choice of template.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The chosen template string.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the user_choice is not found in the templates.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if user_choice in CHAT_TEMPLATES:
|
|
||||||
return CHAT_TEMPLATES[user_choice]
|
|
||||||
|
|
||||||
raise ValueError(f"Template '{user_choice}' not found.")
|
|
||||||
|
|
||||||
|
|
||||||
def register_chat_template(template_name: str, chat_template: str):
|
def register_chat_template(template_name: str, chat_template: str):
|
||||||
"""
|
"""
|
||||||
Registers chat templates.
|
Registers chat templates.
|
||||||
|
|||||||
@@ -764,7 +764,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
phi35_tokenizer,
|
phi35_tokenizer,
|
||||||
chat_template=chat_templates("phi_35"),
|
chat_template=get_chat_template("phi_35"),
|
||||||
message_field_role="role",
|
message_field_role="role",
|
||||||
message_field_content="content",
|
message_field_content="content",
|
||||||
roles={
|
roles={
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import (
|
|||||||
ChatTemplateStrategy,
|
ChatTemplateStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import get_chat_template
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -35,7 +35,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with train_on_inputs=True")
|
LOG.info("Testing with train_on_inputs=True")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
@@ -80,7 +80,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with train_on_inputs=False")
|
LOG.info("Testing with train_on_inputs=False")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -123,7 +123,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing roles_to_train with assistant only")
|
LOG.info("Testing roles_to_train with assistant only")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -151,7 +151,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing roles_to_train with all roles")
|
LOG.info("Testing roles_to_train with all roles")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
@@ -184,7 +184,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with empty roles_to_train")
|
LOG.info("Testing with empty roles_to_train")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -205,7 +205,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with train_on_eos='all'")
|
LOG.info("Testing with train_on_eos='all'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -232,7 +232,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with train_on_eos='turn'")
|
LOG.info("Testing with train_on_eos='turn'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -282,7 +282,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with train_on_eos='last'")
|
LOG.info("Testing with train_on_eos='last'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -315,7 +315,7 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.info("Testing with train_on_eos='none'")
|
LOG.info("Testing with train_on_eos='none'")
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -343,7 +343,7 @@ class TestChatTemplateConfigurations:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=chat_templates("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
drop_system_message=True,
|
drop_system_message=True,
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
@@ -371,7 +371,7 @@ class TestChatTemplateConfigurations:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=chat_templates("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
roles=custom_roles,
|
roles=custom_roles,
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
@@ -424,7 +424,7 @@ class TestChatTemplateConfigurations:
|
|||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=chat_templates("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
message_field_training="train",
|
message_field_training="train",
|
||||||
message_field_training_detail="train_detail",
|
message_field_training_detail="train_detail",
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user