liger + torch compile fix
This commit is contained in:
@@ -27,6 +27,7 @@ from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||
from .utils import patch_with_compile_disable
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
||||
|
||||
@@ -40,6 +41,17 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
if cfg.torch_compile:
|
||||
import liger_kernel.ops.fused_linear_cross_entropy
|
||||
|
||||
patch_with_compile_disable(
|
||||
liger_kernel.ops.fused_linear_cross_entropy,
|
||||
"fused_linear_cross_entropy_forward",
|
||||
)
|
||||
patch_with_compile_disable(
|
||||
liger_kernel.ops.fused_linear_cross_entropy,
|
||||
"fused_linear_cross_entropy_backward",
|
||||
)
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
|
||||
29
src/axolotl/integrations/liger/utils.py
Normal file
29
src/axolotl/integrations/liger/utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
utils to patch liger kernel ops to disable torch.compile
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def patch_with_compile_disable(module, function_name):
|
||||
"""
|
||||
Patch a function in a module by wrapping it with torch.compile.disable
|
||||
|
||||
Args:
|
||||
module: The module containing the function to patch
|
||||
function_name: The name of the function to patch
|
||||
"""
|
||||
original_function = getattr(module, function_name)
|
||||
|
||||
@wraps(original_function)
|
||||
@torch.compiler.disable
|
||||
def wrapped_function(*args, **kwargs):
|
||||
return original_function(*args, **kwargs)
|
||||
|
||||
# Replace the original function with the wrapped one
|
||||
setattr(module, function_name, wrapped_function)
|
||||
|
||||
# Return the original function in case you need to restore it later
|
||||
return original_function
|
||||
Reference in New Issue
Block a user