liger + torch compile fix
This commit is contained in:
@@ -27,6 +27,7 @@ from axolotl.integrations.base import BasePlugin
|
|||||||
|
|
||||||
from ...utils.distributed import zero_only
|
from ...utils.distributed import zero_only
|
||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
from .utils import patch_with_compile_disable
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
LOG = logging.getLogger("axolotl.integrations.liger")
|
||||||
|
|
||||||
@@ -40,6 +41,17 @@ class LigerPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.liger.LigerArgs"
|
return "axolotl.integrations.liger.LigerArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
|
if cfg.torch_compile:
|
||||||
|
import liger_kernel.ops.fused_linear_cross_entropy
|
||||||
|
|
||||||
|
patch_with_compile_disable(
|
||||||
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
|
"fused_linear_cross_entropy_forward",
|
||||||
|
)
|
||||||
|
patch_with_compile_disable(
|
||||||
|
liger_kernel.ops.fused_linear_cross_entropy,
|
||||||
|
"fused_linear_cross_entropy_backward",
|
||||||
|
)
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
|
|||||||
29
src/axolotl/integrations/liger/utils.py
Normal file
29
src/axolotl/integrations/liger/utils.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
utils to patch liger kernel ops to disable torch.compile
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def patch_with_compile_disable(module, function_name):
|
||||||
|
"""
|
||||||
|
Patch a function in a module by wrapping it with torch.compile.disable
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The module containing the function to patch
|
||||||
|
function_name: The name of the function to patch
|
||||||
|
"""
|
||||||
|
original_function = getattr(module, function_name)
|
||||||
|
|
||||||
|
@wraps(original_function)
|
||||||
|
@torch.compiler.disable
|
||||||
|
def wrapped_function(*args, **kwargs):
|
||||||
|
return original_function(*args, **kwargs)
|
||||||
|
|
||||||
|
# Replace the original function with the wrapped one
|
||||||
|
setattr(module, function_name, wrapped_function)
|
||||||
|
|
||||||
|
# Return the original function in case you need to restore it later
|
||||||
|
return original_function
|
||||||
Reference in New Issue
Block a user