diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fffddac81..e81740399 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -68,7 +68,7 @@ from axolotl.utils.callbacks import ( ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.profiler import PytorchProfilerCallback -from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, @@ -1834,8 +1834,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: - training_arguments_kwargs["chat_template"] = get_chat_template( - self.cfg.chat_template, + training_arguments_kwargs["chat_template"] = get_chat_template_from_config( + cfg=self.cfg, tokenizer=self.tokenizer, )