diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 52765a9b5..84586ccc3 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -272,7 +272,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template) + chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) model = model.to(cfg.device, dtype=cfg.torch_dtype) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d125f838d..e47c09d51 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1595,7 +1595,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: training_arguments_kwargs["chat_template"] = get_chat_template( - self.cfg.chat_template + self.cfg.chat_template, + tokenizer=self.tokenizer, ) if self.cfg.rl == "orpo":