From 2c2563bc341ab8896322011d04a32b81ec1daa41 Mon Sep 17 00:00:00 2001 From: Eko Julianto Salim Date: Sat, 26 Apr 2025 04:02:37 +0700 Subject: [PATCH] fix: gradient checkpointing functools.partial object has no attribute __self__ (#2563) [skip ci] * fix: gradient checkpointing causing functools.partial error * lint * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/utils/gradient_checkpointing/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py index 62fd34b59..0da5c83a2 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/utils/gradient_checkpointing/__init__.py @@ -1,5 +1,7 @@ """custom checkpointing utils""" +from functools import partial + from axolotl.utils.gradient_checkpointing.unsloth import ( Unsloth_Offloaded_Gradient_Checkpointer, ) @@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper( decoder_layer, *args, use_reentrant=None ): # pylint: disable=unused-argument return Unsloth_Offloaded_Gradient_Checkpointer.apply( - decoder_layer.__self__, + ( + decoder_layer.func.__self__ + if isinstance(decoder_layer, partial) + else decoder_layer.__self__ + ), *args, )