Garbage collect at the end of the step if we're going to save a checkpoint (#2971) [skip ci]
This commit is contained in:
@@ -27,7 +27,11 @@ from transformers import (
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
from transformers.trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
IntervalStrategy,
|
||||
SaveStrategy,
|
||||
)
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
@@ -879,6 +883,17 @@ class GCCallback(TrainerCallback):
|
||||
self.next_gc_on_begin_step = state.global_step + 1
|
||||
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
||||
self._gc()
|
||||
elif (
|
||||
args.save_strategy == SaveStrategy.STEPS
|
||||
and state.save_steps > 0
|
||||
and state.global_step % state.save_steps == 0
|
||||
):
|
||||
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
|
||||
self._gc()
|
||||
elif state.global_step >= state.max_steps:
|
||||
if args.save_strategy == SaveStrategy.STEPS:
|
||||
# gc on the final step (a checkpoint is presumably saved here too) in case anything is loaded to CPU RAM like offloaded tensors
|
||||
self._gc()
|
||||
|
||||
def on_epoch_end(
|
||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||
|
||||
Reference in New Issue
Block a user