fix: gradient checkpointing functools.partial object has no attribute __self__ (#2563) [skip ci]
* fix: gradient checkpointing causing functools.partial error * lint * chore: lint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
committed by
GitHub
parent
5cb3398460
commit
2c2563bc34
@@ -1,5 +1,7 @@
|
|||||||
"""custom checkpointing utils"""
|
"""custom checkpointing utils"""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||||
)
|
)
|
||||||
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
|
|||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||||
decoder_layer.__self__,
|
(
|
||||||
|
decoder_layer.func.__self__
|
||||||
|
if isinstance(decoder_layer, partial)
|
||||||
|
else decoder_layer.__self__
|
||||||
|
),
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user