Fix get_chat_template call for trainer builder (#2003)

This commit is contained in:
Chirag Jain
2024-10-30 23:57:00 +05:30
committed by GitHub
parent e62554c419
commit 74db2a1bae
2 changed files with 3 additions and 2 deletions

View File

@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template) chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype) model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -1595,7 +1595,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template: if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template( training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template self.cfg.chat_template,
tokenizer=self.tokenizer,
) )
if self.cfg.rl == "orpo": if self.cfg.rl == "orpo":