diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
new file mode 100644
index 000000000..aaf96bcb0
--- /dev/null
+++ b/src/axolotl/utils/callbacks.py
@@ -0,0 +1,22 @@
+import os
+
+from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+class SavePeftModelCallback(TrainerCallback):
+    """Saves only the PEFT adapter weights into each Trainer checkpoint folder."""
+
+    def on_save(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Mirror the Trainer's checkpoint naming, e.g. "checkpoint-500".
+        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+
+        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
+        kwargs["model"].save_pretrained(peft_model_path)
+
+        return control
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 6535c2a7e..baa0ce626 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -13,6 +13,7 @@ from transformers import EarlyStoppingCallback
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.schedulers import InterpolatingLogScheduler
+from axolotl.utils.callbacks import SavePeftModelCallback
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
@@ -188,6 +189,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             data_collator_kwargs["padding"] = "longest"
         else:
             data_collator_kwargs["pad_to_multiple_of"] = 8
+
+    callbacks = []
+    if cfg.adapter == "lora":
+        callbacks.append(SavePeftModelCallback)
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_dataset,
@@ -198,6 +204,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             return_tensors="pt",
             **data_collator_kwargs,
         ),
+        callbacks=callbacks,
         **trainer_kwargs,
     )
 
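
For reference, an adapter saved by this callback can later be loaded back onto the base model with `peft`. A minimal sketch, assuming a hypothetical output directory and base model name (neither appears in the diff above):

```python
# Minimal sketch: loading an adapter written by SavePeftModelCallback.
# "output" and "huggyllama/llama-7b" below are hypothetical placeholders.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")

# The callback saves adapters under <output_dir>/checkpoint-<step>/adapter_model
model = PeftModel.from_pretrained(base_model, "output/checkpoint-500/adapter_model")
```

Note that `callbacks.append(SavePeftModelCallback)` passes the class rather than an instance; the `Trainer` accepts either, instantiating the class itself when the callback is registered.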