fix: gradient checkpointing functools.partial object has no attribute __self__ (#2563) [skip ci]
* fix: gradient checkpointing causing functools.partial error * lint * chore: lint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
committed by
GitHub
parent
5cb3398460
commit
2c2563bc34
@@ -1,5 +1,7 @@
|
||||
"""custom checkpointing utils"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||
)
|
||||
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
decoder_layer.__self__,
|
||||
(
|
||||
decoder_layer.func.__self__
|
||||
if isinstance(decoder_layer, partial)
|
||||
else decoder_layer.__self__
|
||||
),
|
||||
*args,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user