garbage collect on the end of the step if we're going to save a checkpoint (#2971) [skip ci]
This commit is contained in:
@@ -27,7 +27,11 @@ from transformers import (
|
|||||||
TrainerState,
|
TrainerState,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import (
|
||||||
|
PREFIX_CHECKPOINT_DIR,
|
||||||
|
IntervalStrategy,
|
||||||
|
SaveStrategy,
|
||||||
|
)
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
@@ -879,6 +883,17 @@ class GCCallback(TrainerCallback):
|
|||||||
self.next_gc_on_begin_step = state.global_step + 1
|
self.next_gc_on_begin_step = state.global_step + 1
|
||||||
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
||||||
self._gc()
|
self._gc()
|
||||||
|
elif (
|
||||||
|
args.save_strategy == SaveStrategy.STEPS
|
||||||
|
and state.save_steps > 0
|
||||||
|
and state.global_step % state.save_steps == 0
|
||||||
|
):
|
||||||
|
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
|
||||||
|
self._gc()
|
||||||
|
elif state.global_step >= state.max_steps:
|
||||||
|
if args.save_strategy == SaveStrategy.STEPS:
|
||||||
|
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
|
||||||
|
self._gc()
|
||||||
|
|
||||||
def on_epoch_end(
|
def on_epoch_end(
|
||||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||||
|
|||||||
Reference in New Issue
Block a user