Merge pull request #25 from NanoCode012/patch-2

Fix Trainer() got multiple values for keyword argument 'callbacks'
This commit is contained in:
Wing Lian
2023-05-11 09:20:15 -04:00
committed by GitHub

View File

@@ -175,12 +175,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
) )
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = []
# TODO on_save callback to sync checkpoints to GCP/AWS in background # TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience: if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback( early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience, cfg.early_stopping_patience,
) )
trainer_kwargs["callbacks"] = [early_stop_cb] callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter == 'lora': # only save in rank 0
callbacks.append(SavePeftModelCallback)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, "padding": True,
@@ -190,10 +194,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else: else:
data_collator_kwargs["pad_to_multiple_of"] = 8 data_collator_kwargs["pad_to_multiple_of"] = 8
callbacks = []
if cfg.adapter == "lora":
callbacks.append(SavePeftModelCallback)
trainer = transformers.Trainer( trainer = transformers.Trainer(
model=model, model=model,
train_dataset=train_dataset, train_dataset=train_dataset,