Add callback to save peft_model on on_save

This commit is contained in:
NanoCode012
2023-05-09 00:38:27 +09:00
parent 7576d85c73
commit 0d6708bfe4
2 changed files with 25 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
import os
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that persists the PEFT adapter weights on each save.

    Every time the Trainer fires ``on_save``, the adapter is written to an
    ``adapter_model`` subdirectory inside the checkpoint folder for the
    current global step.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Checkpoint folder follows the HF convention: "<prefix>-<global_step>".
        ckpt_dir = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        # Write only the adapter weights, not the full base model.
        adapter_dir = os.path.join(ckpt_dir, "adapter_model")
        kwargs["model"].save_pretrained(adapter_dir)
        return control

View File

@@ -13,6 +13,7 @@ from transformers import EarlyStoppingCallback
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.callbacks import SavePeftModelCallback
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
@@ -188,6 +189,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
data_collator_kwargs["padding"] = "longest"
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
callbacks = []
if cfg.adapter == 'lora':
callbacks.append(SavePeftModelCallback)
trainer = transformers.Trainer(
model=model,
train_dataset=train_dataset,