Feat: Add save_safetensors

This commit is contained in:
NanoCode012
2023-07-14 13:21:47 +09:00
parent ef17e15483
commit 5491278a79
2 changed files with 6 additions and 0 deletions

View File

@@ -411,6 +411,9 @@ logging_steps:
save_steps:
eval_steps:
# save model as safetensors (require safetensors package)
save_safetensors:
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)

View File

@@ -182,6 +182,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
training_args = AxolotlTrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size