From c0f50d9c6151fce3422befcfb73ae14cc136db1b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 28 Apr 2023 22:27:33 -0400
Subject: [PATCH] wire up gradient checkpointing for 4bit

---
 src/axolotl/utils/trainer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 446cc51b5..8ce05ba12 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -28,7 +28,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
     if cfg.gradient_checkpointing is not None:
-        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+        if cfg.load_4bit:
+            from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing
+            gradient_checkpointing_ratio = cfg.gradient_checkpointing_ratio if cfg.gradient_checkpointing_ratio else 1.0
+            apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
+        else:
+            training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+
     # deepspeed
     if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
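
Note: `apply_gradient_checkpointing` from the `alpaca_lora_4bit` package patches the model in place, which is why the 4-bit branch bypasses the Hugging Face `gradient_checkpointing` flag in `TrainingArguments`. As a rough mental model only (a hypothetical sketch, not the `alpaca_lora_4bit` implementation, with made-up helper names), a ratio-based scheme wraps some fraction of the transformer blocks so their activations are recomputed during the backward pass instead of being stored:

# Minimal sketch of ratio-based gradient checkpointing, assuming the model's
# transformer blocks live in an nn.ModuleList. Helper names are hypothetical;
# this is not how alpaca_lora_4bit implements it.
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedLayer(nn.Module):
    """Recomputes the wrapped layer's forward pass during backward."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        # use_reentrant=False selects the non-reentrant checkpoint variant,
        # which supports keyword arguments
        return checkpoint(self.layer, *args, use_reentrant=False, **kwargs)


def apply_checkpointing_sketch(layers: nn.ModuleList, checkpoint_ratio: float = 1.0):
    """Wrap the first `checkpoint_ratio` fraction of layers in place."""
    num_wrapped = int(len(layers) * checkpoint_ratio)
    for i in range(num_wrapped):
        layers[i] = CheckpointedLayer(layers[i])

Under those assumptions, something like `apply_checkpointing_sketch(model.model.layers, 0.5)` on a LLaMA-style model would checkpoint half the blocks: a ratio of 1.0 trades the most compute for the lowest activation memory, while smaller ratios give a middle ground between step speed and peak VRAM.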