use recommended setting for use_reentrant w gradient checkpointing (#1021)

* use recommended setting for use_reentrant w gradient checkpointing

* add doc for gradient_checkpointing_kwargs
This commit is contained in:
Wing Lian
2024-01-01 22:17:27 -05:00
committed by GitHub
parent 3678a6c41d
commit 4d2e842e46
2 changed files with 11 additions and 0 deletions

View File

@@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
if self.cfg.gradient_checkpointing_kwargs:
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
else:
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: