From 813aab378fad74e2fec712cf9da7cc4d746ac649 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Wed, 10 May 2023 18:28:28 +0900
Subject: [PATCH] Fix Trainer() got multiple values for keyword argument
 'callbacks'

---
 src/axolotl/utils/trainer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9ef1ac95b..aa8c72a3c 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -175,12 +175,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     )
     trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
+    callbacks = []
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
         )
-        trainer_kwargs["callbacks"] = [early_stop_cb]
+        callbacks.append(early_stop_cb)
+
+    if cfg.local_rank == 0 and cfg.adapter == 'lora': # only save in rank 0
+        callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
         "padding": True,
@@ -190,10 +194,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
-    callbacks = []
-    if cfg.adapter == 'lora':
-        callbacks.append(SavePeftModelCallback)
-
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_dataset,