From 0d6708bfe4e1e5b998f727fa7060a5f873e029d3 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 9 May 2023 00:38:27 +0900 Subject: [PATCH 1/2] Add callback save peft_model on_save --- src/axolotl/utils/callbacks.py | 19 +++++++++++++++++++ src/axolotl/utils/trainer.py | 6 ++++++ 2 files changed, 25 insertions(+) create mode 100644 src/axolotl/utils/callbacks.py diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py new file mode 100644 index 000000000..aaf96bcb0 --- /dev/null +++ b/src/axolotl/utils/callbacks.py @@ -0,0 +1,19 @@ +import os + +from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + +class SavePeftModelCallback(TrainerCallback): + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") + + peft_model_path = os.path.join(checkpoint_folder, "adapter_model") + kwargs["model"].save_pretrained(peft_model_path) + + return control diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 6535c2a7e..ce3e65ca4 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -13,6 +13,7 @@ from transformers import EarlyStoppingCallback from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.schedulers import InterpolatingLogScheduler +from axolotl.utils.callbacks import SavePeftModelCallback def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): @@ -188,6 +189,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): data_collator_kwargs["padding"] = "longest" else: data_collator_kwargs["pad_to_multiple_of"] = 8 + + callbacks = [] + if cfg.adapter == 'lora': + callbacks.append(SavePeftModelCallback) + trainer = transformers.Trainer( model=model, train_dataset=train_dataset, From cc77bab5267b1e72a2d0333269da0637389761d5 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 9 May 2023 00:41:19 +0900 Subject: [PATCH 2/2] Add callbacks to Trainer --- src/axolotl/utils/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ce3e65ca4..baa0ce626 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -204,6 +204,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): return_tensors="pt", **data_collator_kwargs, ), + callbacks=callbacks, **trainer_kwargs, )