Fix duplication of plugin callbacks (#2090)
This commit is contained in:
@@ -1212,11 +1212,17 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
|
if self.cfg.plugins:
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
callbacks.extend(
|
callbacks.extend(
|
||||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
[
|
||||||
)
|
cb
|
||||||
|
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||||
|
self.cfg, trainer
|
||||||
|
)
|
||||||
|
if cb
|
||||||
|
]
|
||||||
|
)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||||
@@ -1263,7 +1269,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
callbacks = []
|
||||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||||
LogPredictionCallback = log_prediction_callback_factory(
|
LogPredictionCallback = log_prediction_callback_factory(
|
||||||
trainer, self.tokenizer, "wandb"
|
trainer, self.tokenizer, "wandb"
|
||||||
@@ -1301,17 +1307,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
callbacks.append(lisa_callback_factory(trainer))
|
callbacks.append(lisa_callback_factory(trainer))
|
||||||
|
|
||||||
if self.cfg.plugins:
|
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
callbacks.extend(
|
|
||||||
[
|
|
||||||
cb
|
|
||||||
for cb in plugin_manager.add_callbacks_post_trainer(
|
|
||||||
self.cfg, trainer
|
|
||||||
)
|
|
||||||
if cb
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self):
|
def _get_trainer_cls(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user