diff --git a/README.md b/README.md index 98b8a7823..4dd80339a 100644 --- a/README.md +++ b/README.md @@ -741,6 +741,9 @@ group_by_length: false # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing gradient_checkpointing: false +# additional kwargs to pass to the trainer for gradient checkpointing +# gradient_checkpointing_kwargs: +# use_reentrant: false # Stop training after this many evaluation losses have increased in a row # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fed26de46..4ca2877d1 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs[ "gradient_checkpointing" ] = self.cfg.gradient_checkpointing + if self.cfg.gradient_checkpointing_kwargs: + training_arguments_kwargs[ + "gradient_checkpointing_kwargs" + ] = self.cfg.gradient_checkpointing_kwargs + else: + training_arguments_kwargs["gradient_checkpointing_kwargs"] = { + "use_reentrant": False + } if self.cfg.fsdp: training_arguments_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp_config: