diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 31ee3cccf..99d97800c 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -168,6 +168,9 @@ class TrainerBuilderBase(abc.ABC): ) ) + if self.cfg.gc_steps: + callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) + if self.cfg.use_wandb: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) @@ -249,9 +252,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) - if self.cfg.gc_steps: - callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) - return callbacks def get_post_trainer_create_callbacks(self, trainer):