From b1e3e1b25fb37ee6743292bc8e4e3b802ea7f334 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 19 Mar 2024 12:57:43 +0900
Subject: [PATCH] fix(config): passing gradient_checkpoint_kwargs (#1412)

* fix(config): change default use_reentrant to true

* Update trainer_builder.py

* fix: make sure to pass kwargs to enable checkpoint

* chore: lint
---
 README.md                           | 2 +-
 src/axolotl/core/trainer_builder.py | 4 ----
 src/axolotl/utils/models.py         | 4 +++-
 3 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 1b1b87592..1629ae251 100644
--- a/README.md
+++ b/README.md
@@ -859,7 +859,7 @@ group_by_length: false
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
 # gradient_checkpointing_kwargs:
-#   use_reentrant: false
+#   use_reentrant: true
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 42180f32b..374a28df7 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -970,10 +970,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 training_arguments_kwargs[
                     "gradient_checkpointing_kwargs"
                 ] = self.cfg.gradient_checkpointing_kwargs
-            else:
-                training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
-                    "use_reentrant": False
-                }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
             if self.cfg.fsdp_config:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fce7b20a7..40090a07c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -888,7 +888,9 @@ def load_model(
 
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
-            model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
+            )
         if (
             cfg.load_in_8bit or cfg.load_in_4bit
         ) and not skip_prepare_model_for_kbit_training:
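
For reference, the user-facing effect of this patch is that gradient_checkpointing_kwargs from the YAML config is now forwarded both to the trainer arguments and to model.gradient_checkpointing_enable(), instead of the trainer silently defaulting to use_reentrant: false. A minimal config excerpt illustrating the keys documented in the README hunk above (a sketch; the values shown are examples, not mandated by the patch):

# illustrative axolotl config excerpt, not part of the patch
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true  # matches the new documented default in the README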