use recommended setting for use_reentrant w gradient checkpointing (#1021)

* use recommended setting for use_reentrant w gradient checkpointing

* add doc for gradient_checkpointing_kwargs
This commit is contained in:
Wing Lian
2024-01-01 22:17:27 -05:00
committed by GitHub
parent 3678a6c41d
commit 4d2e842e46
2 changed files with 11 additions and 0 deletions

View File

@@ -741,6 +741,9 @@ group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
# use_reentrant: false
# Stop training after this many evaluation losses have increased in a row # Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback

View File

@@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"gradient_checkpointing" "gradient_checkpointing"
] = self.cfg.gradient_checkpointing ] = self.cfg.gradient_checkpointing
if self.cfg.gradient_checkpointing_kwargs:
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
else:
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
if self.cfg.fsdp: if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: if self.cfg.fsdp_config: