diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py index 62fd34b59..0da5c83a2 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/utils/gradient_checkpointing/__init__.py @@ -1,5 +1,7 @@ """custom checkpointing utils""" +from functools import partial + from axolotl.utils.gradient_checkpointing.unsloth import ( Unsloth_Offloaded_Gradient_Checkpointer, ) @@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper( decoder_layer, *args, use_reentrant=None ): # pylint: disable=unused-argument return Unsloth_Offloaded_Gradient_Checkpointer.apply( - decoder_layer.__self__, + ( + decoder_layer.func.__self__ + if isinstance(decoder_layer, partial) + else decoder_layer.__self__ + ), *args, )