Add callback save peft_model on_save
This commit is contained in:
19
src/axolotl/utils/callbacks.py
Normal file
19
src/axolotl/utils/callbacks.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
||||||
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
|
|
||||||
|
class SavePeftModelCallback(TrainerCallback):
|
||||||
|
def on_save(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||||
|
|
||||||
|
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||||
|
kwargs["model"].save_pretrained(peft_model_path)
|
||||||
|
|
||||||
|
return control
|
||||||
@@ -13,6 +13,7 @@ from transformers import EarlyStoppingCallback
|
|||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
||||||
|
from axolotl.utils.callbacks import SavePeftModelCallback
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||||
@@ -188,6 +189,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
data_collator_kwargs["padding"] = "longest"
|
data_collator_kwargs["padding"] = "longest"
|
||||||
else:
|
else:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||||
|
|
||||||
|
callbacks = []
|
||||||
|
if cfg.adapter == 'lora':
|
||||||
|
callbacks.append(SavePeftModelCallback)
|
||||||
|
|
||||||
trainer = transformers.Trainer(
|
trainer = transformers.Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
|
|||||||
Reference in New Issue
Block a user