Fix get_chat_template call for trainer builder (#2003)

This commit is contained in:
Chirag Jain
2024-10-30 23:57:00 +05:30
committed by GitHub
parent e62554c419
commit 74db2a1bae
2 changed files with 3 additions and 2 deletions

View File

@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
-        chat_template_str = get_chat_template(cfg.chat_template)
+        chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -1595,7 +1595,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
-                self.cfg.chat_template
+                self.cfg.chat_template,
+                tokenizer=self.tokenizer,
)
if self.cfg.rl == "orpo":