From 4d2e842e46bf8bd6dd0fda4d2667a7e7d80b4cd4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 1 Jan 2024 22:17:27 -0500
Subject: [PATCH] use recommended setting for use_reentrant w gradient
 checkpointing (#1021)

* use recommended setting for use_reentrant w gradient checkpointing

* add doc for gradient_checkpointing_kwargs
---
 README.md                           | 3 +++
 src/axolotl/core/trainer_builder.py | 8 ++++++++
 2 files changed, 11 insertions(+)

diff --git a/README.md b/README.md
index 98b8a7823..4dd80339a 100644
--- a/README.md
+++ b/README.md
@@ -741,6 +741,9 @@ group_by_length: false
 
 # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
+# additional kwargs to pass to the trainer for gradient checkpointing
+# gradient_checkpointing_kwargs:
+#   use_reentrant: false
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index fed26de46..4ca2877d1 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs[
             "gradient_checkpointing"
         ] = self.cfg.gradient_checkpointing
+        if self.cfg.gradient_checkpointing_kwargs:
+            training_arguments_kwargs[
+                "gradient_checkpointing_kwargs"
+            ] = self.cfg.gradient_checkpointing_kwargs
+        else:
+            training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
+                "use_reentrant": False
+            }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
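
Usage note (illustrative, not part of the patch): after this change, the trainer
builder falls back to {"use_reentrant": False} whenever the config does not set
gradient_checkpointing_kwargs, so the recommended non-reentrant checkpointing is
the default. A config can still override it explicitly; a minimal sketch of the
relevant YAML keys, assuming the rest of the axolotl config (base_model,
datasets, etc.) is already filled in:

    gradient_checkpointing: true
    gradient_checkpointing_kwargs:
      use_reentrant: false

Omitting gradient_checkpointing_kwargs entirely now yields the same behavior as
the explicit use_reentrant: false shown above.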