Fix get_chat_template call for trainer builder (#2003)
This commit is contained in:
@@ -272,7 +272,7 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
elif cfg.chat_template:
|
elif cfg.chat_template:
|
||||||
chat_template_str = get_chat_template(cfg.chat_template)
|
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -1595,7 +1595,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||||
if self.cfg.chat_template:
|
if self.cfg.chat_template:
|
||||||
training_arguments_kwargs["chat_template"] = get_chat_template(
|
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||||
self.cfg.chat_template
|
self.cfg.chat_template,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
|
|||||||
Reference in New Issue
Block a user