garbage collect at the end of the step if we're going to save a checkpoint (#2971) [skip ci]

This commit is contained in:
Wing Lian
2025-07-24 16:10:23 -04:00
committed by GitHub
parent 0ff2f172ef
commit e80faea0db

View File

@@ -27,7 +27,11 @@ from transformers import (
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
IntervalStrategy,
SaveStrategy,
)
from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -879,6 +883,17 @@ class GCCallback(TrainerCallback):
self.next_gc_on_begin_step = state.global_step + 1
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
self._gc()
elif (
args.save_strategy == SaveStrategy.STEPS
and state.save_steps > 0
and state.global_step % state.save_steps == 0
):
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
self._gc()
elif state.global_step >= state.max_steps:
if args.save_strategy == SaveStrategy.STEPS:
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
self._gc()
def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument