diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index b0738bf27..2d031aa03 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -27,7 +27,11 @@ from transformers import ( TrainerState, TrainingArguments, ) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +from transformers.trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + IntervalStrategy, + SaveStrategy, +) from trl.models import unwrap_model_for_generation from axolotl.utils import is_comet_available, is_mlflow_available @@ -879,6 +883,17 @@ class GCCallback(TrainerCallback): self.next_gc_on_begin_step = state.global_step + 1 elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0: self._gc() + elif ( + args.save_strategy == SaveStrategy.STEPS + and state.save_steps > 0 + and state.global_step % state.save_steps == 0 + ): + # gc on save steps in case anything is loaded to CPU RAM like offloaded tensors + self._gc() + elif state.global_step >= state.max_steps: + if args.save_strategy == SaveStrategy.STEPS: + # gc on the final step too when saving by steps, in case anything is loaded to CPU RAM like offloaded tensors + self._gc() def on_epoch_end( self, args, state, control, **kwargs # pylint: disable=unused-argument