fix: gradient checkpointing functools.partial object has no attribute __self__ (#2563) [skip ci]

* fix: gradient checkpointing causing functools.partial error

* lint

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Eko Julianto Salim
2025-04-26 04:02:37 +07:00
committed by GitHub
parent 5cb3398460
commit 2c2563bc34

View File

@@ -1,5 +1,7 @@
"""custom checkpointing utils"""
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)