diff --git a/src/axolotl/core/trainers/mixins/activation_checkpointing.py b/src/axolotl/core/trainers/mixins/activation_checkpointing.py
index b61c45fee..d6720be30 100644
--- a/src/axolotl/core/trainers/mixins/activation_checkpointing.py
+++ b/src/axolotl/core/trainers/mixins/activation_checkpointing.py
@@ -3,11 +3,14 @@ Trainer mixin for activation checkpointing w offloading
 """
 
 import contextlib
+from functools import partial
 
 from peft import PeftModel
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
+    checkpoint_wrapper,
+    CheckpointImpl,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from transformers import GradientCheckpointingLayer, Trainer
@@ -46,9 +49,20 @@ class ActivationOffloadingMixin(Trainer):
         return super().training_step(*args, **kwargs)
 
 
-def ac_wrap_hf_model(model: nn.Module, **kwargs):
+def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
     auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
-    apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
+    if use_reentrant:
+        checkpoint_wrapper_fn = partial(
+            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
+        )
+    else:
+        checkpoint_wrapper_fn = checkpoint_wrapper
+    apply_activation_checkpointing(
+        model,
+        checkpoint_wrapper_fn=checkpoint_wrapper_fn,
+        auto_wrap_policy=auto_wrap_policy,
+        **kwargs,
+    )
 
 
 def get_lora_act_offloading_ctx_manager(
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index a9507d685..34723f4c2 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -224,21 +224,27 @@ class ModelLoader:
         ):
             self.model = self.model.merge_and_unload()
 
-        self._apply_activation_checkpointing()
+        use_reentrant = None
+        if (
+            self.cfg.gradient_checkpointing_kwargs
+            and self.cfg.gradient_checkpointing_kwargs.get("use_reentrant", True)
+        ):
+            use_reentrant = True
+        self._apply_activation_checkpointing(use_reentrant=use_reentrant)
         self._resize_token_embeddings()
         self._adjust_model_config()
         self._configure_embedding_dtypes()
         self._configure_qat()
         log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
 
-    def _apply_activation_checkpointing(self):
+    def _apply_activation_checkpointing(self, use_reentrant: bool | None = None):
         if self.cfg.activation_offloading is True:
             from axolotl.core.trainers.mixins.activation_checkpointing import (
                 ac_wrap_hf_model,
             )
 
             # ^^ importing this at the module level breaks plugins
-            ac_wrap_hf_model(self.model)
+            ac_wrap_hf_model(self.model, use_reentrant=use_reentrant)
 
     def _resize_token_embeddings(self):
         """Resize token embeddings if needed."""
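
For reference, below is a minimal standalone sketch of the dispatch this patch introduces: selecting between PyTorch's reentrant and non-reentrant checkpoint implementations and wrapping matching submodules. The PyTorch calls (`checkpoint_wrapper`, `CheckpointImpl`, `apply_activation_checkpointing`, `ModuleWrapPolicy`) are the real APIs used in the diff; `ToyBlock`, `wrap_with_activation_checkpointing`, and the toy model are hypothetical stand-ins for transformers' `GradientCheckpointingLayer`-based modules, not Axolotl code.

```python
# Minimal sketch (not Axolotl code) of the use_reentrant dispatch above,
# applied to a toy model so it can run without transformers or peft.
from functools import partial

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a GradientCheckpointingLayer subclass."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x).relu()


def wrap_with_activation_checkpointing(model: nn.Module, use_reentrant=None):
    # Mirrors the patched ac_wrap_hf_model: pick the checkpoint
    # implementation, then wrap every ToyBlock instance in-place.
    if use_reentrant:
        checkpoint_wrapper_fn = partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
        )
    else:
        # checkpoint_wrapper defaults to CheckpointImpl.NO_REENTRANT
        checkpoint_wrapper_fn = checkpoint_wrapper
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper_fn,
        auto_wrap_policy=ModuleWrapPolicy({ToyBlock}),
    )


model = nn.Sequential(ToyBlock(), ToyBlock())
wrap_with_activation_checkpointing(model, use_reentrant=True)
print(model)  # each ToyBlock is now wrapped in a CheckpointWrapper
```

Leaving `use_reentrant` unset falls through to the non-reentrant implementation, which is the default behavior of `checkpoint_wrapper` and the variant PyTorch currently recommends.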